@@ -798,31 +798,32 @@ def main(args):
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        # there are only two options here. Either are just the unet attn processor layers
-        # or there are the unet and text encoder atten layers
-        unet_lora_layers_to_save = None
-        text_encoder_one_lora_layers_to_save = None
-        text_encoder_two_lora_layers_to_save = None
-
-        for model in models:
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+        if accelerator.is_main_process:
+            # there are only two options here. Either are just the unet attn processor layers
+            # or there are the unet and text encoder atten layers
+            unet_lora_layers_to_save = None
+            text_encoder_one_lora_layers_to_save = None
+            text_encoder_two_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
 
-            # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
 
-        StableDiffusionXLPipeline.save_lora_weights(
-            output_dir,
-            unet_lora_layers=unet_lora_layers_to_save,
-            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-            text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
-        )
+            StableDiffusionXLPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+            )
 
     def load_model_hook(models, input_dir):
         unet_ = None
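
Note (not part of the diff): the two hooks above only take effect once they are registered with the `Accelerator`. A minimal sketch of the typical wiring, assuming an Accelerate version that exposes the state-hook registration API:

    # Route `accelerator.save_state(...)` / `accelerator.load_state(...)`
    # through the custom hooks defined above.
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)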