@@ -679,6 +679,43 @@ def test_unload_lora_sdxl(self):
679679 orig_image_slice , orig_image_slice_two , atol = 1e-3
680680 ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
681681
682+ def test_load_lora_locally (self ):
683+ pipeline_components , lora_components = self .get_dummy_components ()
684+ sd_pipe = StableDiffusionXLPipeline (** pipeline_components )
685+ sd_pipe = sd_pipe .to (torch_device )
686+ sd_pipe .set_progress_bar_config (disable = None )
687+
688+ with tempfile .TemporaryDirectory () as tmpdirname :
689+ StableDiffusionXLPipeline .save_lora_weights (
690+ save_directory = tmpdirname ,
691+ unet_lora_layers = lora_components ["unet_lora_layers" ],
692+ text_encoder_lora_layers = lora_components ["text_encoder_one_lora_layers" ],
693+ text_encoder_2_lora_layers = lora_components ["text_encoder_two_lora_layers" ],
694+ )
695+ self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.bin" )))
696+ sd_pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
697+
698+ sd_pipe .unload_lora_weights ()
699+
700+ def test_load_lora_locally_safetensors (self ):
701+ pipeline_components , lora_components = self .get_dummy_components ()
702+ sd_pipe = StableDiffusionXLPipeline (** pipeline_components )
703+ sd_pipe = sd_pipe .to (torch_device )
704+ sd_pipe .set_progress_bar_config (disable = None )
705+
706+ with tempfile .TemporaryDirectory () as tmpdirname :
707+ StableDiffusionXLPipeline .save_lora_weights (
708+ save_directory = tmpdirname ,
709+ unet_lora_layers = lora_components ["unet_lora_layers" ],
710+ text_encoder_lora_layers = lora_components ["text_encoder_one_lora_layers" ],
711+ text_encoder_2_lora_layers = lora_components ["text_encoder_two_lora_layers" ],
712+ safe_serialization = True ,
713+ )
714+ self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" )))
715+ sd_pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" ))
716+
717+ sd_pipe .unload_lora_weights ()
718+
682719
683720@slow
684721@require_torch_gpu
0 commit comments