
Commit 8009272

sayakpaul, stevhliu, and patrickvonplaten authored
[Tests and Docs] Add a test on serializing pipelines with components containing fused LoRA modules (huggingface#4962)
* add: test to ensure pipelines can be saved with fused lora modules.
* add docs about serialization with fused lora.
* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* Empty-Commit
* Update docs/source/en/training/lora.md

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 1037287 commit 8009272

File tree

2 files changed: +85 -11 lines changed

* docs/source/en/training/lora.md
* tests/models/test_lora_layers.py


docs/source/en/training/lora.md

Lines changed: 44 additions & 3 deletions
@@ -34,7 +34,7 @@ the attention layers of a language model is sufficient to obtain good downstream
 
 [cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 🧨 Diffusers now supports finetuning with LoRA for [text-to-image generation](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) and [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora). This guide will show you how to do both.
 
-If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](hf.co/join) if you don't have one already):
+If you'd like to store or share your model with the community, log in to your Hugging Face account (create [one](https://hf.co/join) if you don't have one already):
 
 ```bash
 huggingface-cli login
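
The same login can also be performed from Python rather than the shell: the `huggingface_hub` package that Diffusers depends on provides a `login()` helper. A minimal sketch, where the token string is a placeholder for your own access token:

```python
# Programmatic alternative to `huggingface-cli login`.
from huggingface_hub import login

# Pass a token explicitly, or call login() with no arguments for an interactive prompt.
login(token="hf_xxx")  # placeholder token
```
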
@@ -321,7 +321,7 @@ pipe.fuse_lora()
 
 generator = torch.manual_seed(0)
 images_fusion = pipe(
-    "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+    "masterpiece, best quality, mountain", generator=generator, num_inference_steps=2
 ).images
 
 # To work with a different `lora_scale`, first reverse the effects of `fuse_lora()`.
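
The comment above refers to undoing an earlier fusion before fusing again at a different `lora_scale`. A minimal sketch of that sequence, assuming `pipe` is the pipeline from this guide and that `unfuse_lora()` is available as the counterpart to `fuse_lora()`:

```python
# Undo the previous fusion so the LoRA weights can be re-applied at a different scale.
pipe.unfuse_lora()

# Fuse again, this time scaling the LoRA contribution down to 0.5.
pipe.fuse_lora(lora_scale=0.5)
```
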
@@ -333,7 +333,48 @@ pipe.fuse_lora(lora_scale=0.5)
 
 generator = torch.manual_seed(0)
 images_fusion = pipe(
-    "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+    "masterpiece, best quality, mountain", generator=generator, num_inference_steps=2
+).images
+```
+
+## Serializing pipelines with fused LoRA parameters
+
+Let's say you want to serialize the pipeline above, whose UNet already has the LoRA parameters fused into it, so that it can be reloaded later without the separate LoRA files. You can do so by simply calling the `save_pretrained()` method on `pipe`.
+
+In general, after loading LoRA parameters into a pipeline, if you want to serialize it such that the affected model components already have the LoRA parameters fused in, you should:
+
+* call `fuse_lora()` on the pipeline with the desired `lora_scale`, provided you have already loaded the LoRA parameters into it.
+* call `save_pretrained()` on the pipeline.
+
+Here is a complete example:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+# First, fuse the LoRA parameters.
+pipe.fuse_lora()
+
+# Then save.
+pipe.save_pretrained("my-pipeline-with-fused-lora")
+```
+
+Now you can load the pipeline and run inference directly without having to load the LoRA parameters again:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("my-pipeline-with-fused-lora", torch_dtype=torch.float16).to("cuda")
+
+generator = torch.manual_seed(0)
+images_fusion = pipe(
+    "masterpiece, best quality, mountain", generator=generator, num_inference_steps=2
 ).images
 ```
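
As a quick sanity check on the round trip documented above, you can compare the output of the in-memory fused pipeline with the output of the pipeline reloaded from disk. A minimal sketch, assuming it runs right after the first snippet above (so `pipe` is the fused, just-saved pipeline and `torch` and `DiffusionPipeline` are already imported); `output_type="np"` is used only to make the outputs easy to compare, and the tolerance is illustrative:

```python
import numpy as np

prompt = "masterpiece, best quality, mountain"

# Output of the pipeline whose LoRA parameters were fused in memory and then saved.
images_before = pipe(
    prompt, output_type="np", generator=torch.manual_seed(0), num_inference_steps=2
).images

# Output of a fresh pipeline reloaded from disk with the LoRA parameters already fused in.
pipe_reloaded = DiffusionPipeline.from_pretrained(
    "my-pipeline-with-fused-lora", torch_dtype=torch.float16
).to("cuda")
images_after = pipe_reloaded(
    prompt, output_type="np", generator=torch.manual_seed(0), num_inference_steps=2
).images

# The two runs should closely agree; this mirrors the assertion in the test added below.
assert np.allclose(images_before, images_after, atol=1e-3)
```
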

tests/models/test_lora_layers.py

Lines changed: 41 additions & 8 deletions
@@ -965,15 +965,11 @@ def test_with_different_scales_fusion_equivalence(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
         sd_pipe = sd_pipe.to(torch_device)
-        # sd_pipe.unet.set_default_attn_processor()
         sd_pipe.set_progress_bar_config(disable=None)
 
         _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
 
-        images = sd_pipe(
-            **pipeline_inputs,
-            generator=torch.manual_seed(0),
-        ).images
+        images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
         images_slice = images[0, -3:, -3:, -1]
 
         # Emulate training.
@@ -993,9 +989,7 @@ def test_with_different_scales_fusion_equivalence(self):
             sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
 
         lora_images_scale_0_5 = sd_pipe(
-            **pipeline_inputs,
-            generator=torch.manual_seed(0),
-            cross_attention_kwargs={"scale": 0.5},
+            **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
         ).images
         lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
 
@@ -1017,6 +1011,45 @@ def test_with_different_scales_fusion_equivalence(self):
             images_slice, lora_image_slice_scale_0_5, atol=1e-03
         ), "0.5 scale and no scale shouldn't match"
 
+    def test_save_load_fused_lora_modules(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        sd_pipe.fuse_lora()
+        lora_images_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_fusion = lora_images_fusion[0, -3:, -3:, -1]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            sd_pipe.save_pretrained(tmpdirname)
+            sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname)
+
+        loaded_lora_images = sd_pipe_loaded(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        loaded_lora_image_slice = loaded_lora_images[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            lora_image_slice_fusion, loaded_lora_image_slice, atol=1e-03
+        ), "The pipeline was serialized with the LoRA parameters fused into the respective modules, so the loaded pipeline should yield matching outputs."
+
 
 @slow
 @require_torch_gpu
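
To exercise only the new test locally, pytest's `-k` name filter can select it. A minimal sketch, assuming it is run from the root of a diffusers checkout with the test dependencies installed:

```python
# Run only the new fused-LoRA serialization test.
import pytest

pytest.main(["tests/models/test_lora_layers.py", "-k", "test_save_load_fused_lora_modules"])
```
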
