@@ -83,9 +83,9 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
8383 return text_encoder_lora_layers
8484
8585
86- def set_lora_weights (text_lora_attn_parameters , randn_weight = False ):
86+ def set_lora_weights (lora_attn_parameters , randn_weight = False ):
8787 with torch .no_grad ():
88- for parameter in text_lora_attn_parameters :
88+ for parameter in lora_attn_parameters :
8989 if randn_weight :
9090 parameter [:] = torch .randn_like (parameter )
9191 else :
@@ -155,7 +155,7 @@ def get_dummy_components(self):
155155 }
156156 return pipeline_components , lora_components
157157
158- def get_dummy_inputs (self ):
158+ def get_dummy_inputs (self , with_generator = True ):
159159 batch_size = 1
160160 sequence_length = 10
161161 num_channels = 4
@@ -167,16 +167,16 @@ def get_dummy_inputs(self):
167167
168168 pipeline_inputs = {
169169 "prompt" : "A painting of a squirrel eating a burger" ,
170- "generator" : generator ,
171170 "num_inference_steps" : 2 ,
172171 "guidance_scale" : 6.0 ,
173- "output_type" : "numpy " ,
172+ "output_type" : "np " ,
174173 }
174+ if with_generator :
175+ pipeline_inputs .update ({"generator" : generator })
175176
176177 return noise , input_ids , pipeline_inputs
177178
178- # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
179-
179+ # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
180180 def get_dummy_tokens (self ):
181181 max_seq_length = 77
182182
@@ -399,6 +399,45 @@ def test_lora_unet_attn_processors(self):
399399 )
400400 self .assertIsInstance (module .processor , attn_proc_class )
401401
402+ def test_unload_lora (self ):
403+ pipeline_components , lora_components = self .get_dummy_components ()
404+ _ , _ , pipeline_inputs = self .get_dummy_inputs (with_generator = False )
405+ sd_pipe = StableDiffusionPipeline (** pipeline_components )
406+
407+ original_images = sd_pipe (** pipeline_inputs , generator = torch .manual_seed (0 )).images
408+ orig_image_slice = original_images [0 , - 3 :, - 3 :, - 1 ]
409+
410+ # Emulate training.
411+ set_lora_weights (lora_components ["unet_lora_layers" ].parameters (), randn_weight = True )
412+ set_lora_weights (lora_components ["text_encoder_lora_layers" ].parameters (), randn_weight = True )
413+
414+ with tempfile .TemporaryDirectory () as tmpdirname :
415+ LoraLoaderMixin .save_lora_weights (
416+ save_directory = tmpdirname ,
417+ unet_lora_layers = lora_components ["unet_lora_layers" ],
418+ text_encoder_lora_layers = lora_components ["text_encoder_lora_layers" ],
419+ )
420+ self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.bin" )))
421+ sd_pipe .load_lora_weights (tmpdirname )
422+
423+ lora_images = sd_pipe (** pipeline_inputs , generator = torch .manual_seed (0 )).images
424+ lora_image_slice = lora_images [0 , - 3 :, - 3 :, - 1 ]
425+
426+ # Unload LoRA parameters.
427+ sd_pipe .unload_lora_weights ()
428+ original_images_two = sd_pipe (** pipeline_inputs , generator = torch .manual_seed (0 )).images
429+ orig_image_slice_two = original_images_two [0 , - 3 :, - 3 :, - 1 ]
430+
431+ assert not np .allclose (
432+ orig_image_slice , lora_image_slice
433+ ), "LoRA parameters should lead to a different image slice."
434+ assert not np .allclose (
435+ orig_image_slice_two , lora_image_slice
436+ ), "LoRA parameters should lead to a different image slice."
437+ assert np .allclose (
438+ orig_image_slice , orig_image_slice_two , atol = 1e-3
439+ ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
440+
402441 @unittest .skipIf (torch_device != "cuda" , "This test is supposed to run on GPU" )
403442 def test_lora_unet_attn_processors_with_xformers (self ):
404443 with tempfile .TemporaryDirectory () as tmpdirname :
@@ -537,3 +576,35 @@ def test_vanilla_funetuning(self):
537576 expected = np .array ([0.7406 , 0.699 , 0.5963 , 0.7493 , 0.7045 , 0.6096 , 0.6886 , 0.6388 , 0.583 ])
538577
539578 self .assertTrue (np .allclose (images , expected , atol = 1e-4 ))
579+
580+ def test_unload_lora (self ):
581+ generator = torch .manual_seed (0 )
582+ prompt = "masterpiece, best quality, mountain"
583+ num_inference_steps = 2
584+
585+ pipe = StableDiffusionPipeline .from_pretrained ("runwayml/stable-diffusion-v1-5" , safety_checker = None ).to (
586+ torch_device
587+ )
588+ initial_images = pipe (
589+ prompt , output_type = "np" , generator = generator , num_inference_steps = num_inference_steps
590+ ).images
591+ initial_images = initial_images [0 , - 3 :, - 3 :, - 1 ].flatten ()
592+
593+ lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
594+ lora_filename = "Colored_Icons_by_vizsumit.safetensors"
595+
596+ pipe .load_lora_weights (lora_model_id , weight_name = lora_filename )
597+ lora_images = pipe (
598+ prompt , output_type = "np" , generator = generator , num_inference_steps = num_inference_steps
599+ ).images
600+ lora_images = lora_images [0 , - 3 :, - 3 :, - 1 ].flatten ()
601+
602+ pipe .unload_lora_weights ()
603+ generator = torch .manual_seed (0 )
604+ unloaded_lora_images = pipe (
605+ prompt , output_type = "np" , generator = generator , num_inference_steps = num_inference_steps
606+ ).images
607+ unloaded_lora_images = unloaded_lora_images [0 , - 3 :, - 3 :, - 1 ].flatten ()
608+
609+ self .assertFalse (np .allclose (initial_images , lora_images ))
610+ self .assertTrue (np .allclose (initial_images , unloaded_lora_images , atol = 1e-3 ))
0 commit comments