Commit 692b7a9

[Feat] add: utility for unloading lora. (huggingface#4034)
* add: test for testing unloading lora.
* add: reason to skipif.
* initial implementation of lora unload().
* apply styling.
* add: doc.
* change checkpoints.
* reinit generator.
* finalize slow test.
* add fast test for unloading lora.
1 parent 71c918b commit 692b7a9

File tree

3 files changed (+116, -7 lines)

docs/source/en/training/lora.mdx

Lines changed: 4 additions & 0 deletions

@@ -280,6 +280,10 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
 **Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
 refer to the respective docstrings.
 
+## Unloading LoRA parameters
+
+You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pipeline to unload the LoRA parameters.
+
 ## Supporting A1111 themed LoRA checkpoints from Diffusers
 
 To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
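
For context, a minimal end-to-end sketch of the documented workflow (not part of the commit; the LoRA repo and weight file are the ones used in this commit's slow test, and a `cuda` device is assumed):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")

# Load LoRA parameters into the UNet attention processors and the text encoder.
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)
lora_image = pipe(
    "masterpiece, best quality, mountain", generator=torch.manual_seed(0)
).images[0]

# Unload the LoRA parameters to restore the pipeline's original behavior.
pipe.unload_lora_weights()
plain_image = pipe(
    "masterpiece, best quality, mountain", generator=torch.manual_seed(0)
).images[0]
```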

src/diffusers/loaders.py

Lines changed: 34 additions & 0 deletions

@@ -25,6 +25,8 @@
 from .models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    AttnProcessor,
+    AttnProcessor2_0,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
@@ -1270,6 +1272,38 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
         new_state_dict = {**unet_state_dict, **te_state_dict}
         return new_state_dict, network_alpha
 
+    def unload_lora_weights(self):
+        """
+        Unloads the LoRA parameters.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
+        is_unet_lora = all(
+            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
+            for _, processor in self.unet.attn_processors.items()
+        )
+        # Handle attention processors that are a mix of regular attention and AddedKV
+        # attention.
+        if is_unet_lora:
+            is_attn_procs_mixed = all(
+                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
+                for _, processor in self.unet.attn_processors.items()
+            )
+            if not is_attn_procs_mixed:
+                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                self.unet.set_attn_processor(unet_attn_proc_cls())
+            else:
+                self.unet.set_default_attn_processor()
+
+        # Safe to call the following regardless of LoRA.
+        self._remove_text_encoder_monkey_patch()
+
 
 class FromSingleFileMixin:
     """

tests/models/test_lora_layers.py

Lines changed: 78 additions & 7 deletions

@@ -83,9 +83,9 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
     return text_encoder_lora_layers
 
 
-def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
+def set_lora_weights(lora_attn_parameters, randn_weight=False):
     with torch.no_grad():
-        for parameter in text_lora_attn_parameters:
+        for parameter in lora_attn_parameters:
             if randn_weight:
                 parameter[:] = torch.randn_like(parameter)
             else:
@@ -155,7 +155,7 @@ def get_dummy_components(self):
         }
         return pipeline_components, lora_components
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
         sequence_length = 10
         num_channels = 4
@@ -167,16 +167,16 @@ def get_dummy_inputs(self):
 
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
         }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
 
         return noise, input_ids, pipeline_inputs
 
-    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
-
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
     def get_dummy_tokens(self):
         max_seq_length = 77
 
@@ -399,6 +399,45 @@ def test_lora_unet_attn_processors(self):
                 )
                 self.assertIsInstance(module.processor, attn_proc_class)
 
+    def test_unload_lora(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+
+        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Unload LoRA parameters.
+        sd_pipe.unload_lora_weights()
+        original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            orig_image_slice, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert not np.allclose(
+            orig_image_slice_two, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert np.allclose(
+            orig_image_slice, orig_image_slice_two, atol=1e-3
+        ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_unet_attn_processors_with_xformers(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -537,3 +576,35 @@ def test_vanilla_funetuning(self):
         expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_unload_lora(self):
+        generator = torch.manual_seed(0)
+        prompt = "masterpiece, best quality, mountain"
+        num_inference_steps = 2
+
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
+            torch_device
+        )
+        initial_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        initial_images = initial_images[0, -3:, -3:, -1].flatten()
+
+        lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
+        lora_filename = "Colored_Icons_by_vizsumit.safetensors"
+
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        lora_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        lora_images = lora_images[0, -3:, -3:, -1].flatten()
+
+        pipe.unload_lora_weights()
+        generator = torch.manual_seed(0)
+        unloaded_lora_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
+
+        self.assertFalse(np.allclose(initial_images, lora_images))
+        self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
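
Condensed, the comparison pattern both tests rely on looks like the sketch below; `slice_of` is a hypothetical helper, and `pipe`, `prompt`, `lora_model_id`, and `lora_filename` are assumed from the slow test above. Reseeding the generator before each run makes the pre-LoRA and post-unload outputs directly comparable:

```python
import numpy as np
import torch

def slice_of(images):
    # Hypothetical helper: compare a small corner slice instead of full images.
    return images[0, -3:, -3:, -1].flatten()

before = slice_of(
    pipe(prompt, output_type="np", generator=torch.manual_seed(0), num_inference_steps=2).images
)

pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
with_lora = slice_of(
    pipe(prompt, output_type="np", generator=torch.manual_seed(0), num_inference_steps=2).images
)

pipe.unload_lora_weights()
after = slice_of(
    pipe(prompt, output_type="np", generator=torch.manual_seed(0), num_inference_steps=2).images
)

assert not np.allclose(before, with_lora)  # LoRA changed the output
assert np.allclose(before, after, atol=1e-3)  # unloading restored it
```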
