5656from diffusers .training_utils import compute_snr
5757from diffusers .utils import check_min_version , convert_state_dict_to_diffusers , is_wandb_available
5858from diffusers .utils .import_utils import is_xformers_available
59+ from diffusers .utils .torch_utils import is_compiled_module
5960
6061
6162# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1007,6 +1008,11 @@ def main(args):
10071008 if param .requires_grad :
10081009 param .data = param .to (torch .float32 )
10091010
1011+ def unwrap_model (model ):
1012+ model = accelerator .unwrap_model (model )
1013+ model = model ._orig_mod if is_compiled_module (model ) else model
1014+ return model
1015+
10101016 # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
10111017 def save_model_hook (models , weights , output_dir ):
10121018 if accelerator .is_main_process :
@@ -1017,13 +1023,13 @@ def save_model_hook(models, weights, output_dir):
10171023 text_encoder_two_lora_layers_to_save = None
10181024
10191025 for model in models :
1020- if isinstance (model , type (accelerator . unwrap_model (unet ))):
1026+ if isinstance (model , type (unwrap_model (unet ))):
10211027 unet_lora_layers_to_save = convert_state_dict_to_diffusers (get_peft_model_state_dict (model ))
1022- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_one ))):
1028+ elif isinstance (model , type (unwrap_model (text_encoder_one ))):
10231029 text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers (
10241030 get_peft_model_state_dict (model )
10251031 )
1026- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_two ))):
1032+ elif isinstance (model , type (unwrap_model (text_encoder_two ))):
10271033 text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers (
10281034 get_peft_model_state_dict (model )
10291035 )
@@ -1048,11 +1054,11 @@ def load_model_hook(models, input_dir):
10481054 while len (models ) > 0 :
10491055 model = models .pop ()
10501056
1051- if isinstance (model , type (accelerator . unwrap_model (unet ))):
1057+ if isinstance (model , type (unwrap_model (unet ))):
10521058 unet_ = model
1053- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_one ))):
1059+ elif isinstance (model , type (unwrap_model (text_encoder_one ))):
10541060 text_encoder_one_ = model
1055- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_two ))):
1061+ elif isinstance (model , type (unwrap_model (text_encoder_two ))):
10561062 text_encoder_two_ = model
10571063 else :
10581064 raise ValueError (f"unexpected save model: { model .__class__ } " )
@@ -1621,16 +1627,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
16211627 # Save the lora layers
16221628 accelerator .wait_for_everyone ()
16231629 if accelerator .is_main_process :
1624- unet = accelerator . unwrap_model (unet )
1630+ unet = unwrap_model (unet )
16251631 unet = unet .to (torch .float32 )
16261632 unet_lora_layers = convert_state_dict_to_diffusers (get_peft_model_state_dict (unet ))
16271633
16281634 if args .train_text_encoder :
1629- text_encoder_one = accelerator . unwrap_model (text_encoder_one )
1635+ text_encoder_one = unwrap_model (text_encoder_one )
16301636 text_encoder_lora_layers = convert_state_dict_to_diffusers (
16311637 get_peft_model_state_dict (text_encoder_one .to (torch .float32 ))
16321638 )
1633- text_encoder_two = accelerator . unwrap_model (text_encoder_two )
1639+ text_encoder_two = unwrap_model (text_encoder_two )
16341640 text_encoder_2_lora_layers = convert_state_dict_to_diffusers (
16351641 get_peft_model_state_dict (text_encoder_two .to (torch .float32 ))
16361642 )
0 commit comments