[lora] fix: lora unloading behaviour #11822

Merged: 4 commits, Jun 28, 2025
2 changes: 2 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -693,6 +693,8 @@ def unload_lora(self):
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
if hasattr(self, "_hf_peft_config_loaded"):
self._hf_peft_config_loaded = None

_maybe_remove_and_reapply_group_offloading(self)

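For context on why resetting the flag matters: the PEFT integration's `add_adapter` consults `_hf_peft_config_loaded` before it touches `peft_config`, so deleting `peft_config` while leaving the flag set can make a later `add_adapter` call fail on stale state. The following is a minimal sketch of that interaction, paraphrased for illustration rather than copied from diffusers.

# Minimal sketch of the stale-flag failure mode this hunk addresses.
# It paraphrases the shape of the PEFT-integration mixin; it is not the
# literal diffusers implementation.
class PeftMixinSketch:
    _hf_peft_config_loaded = False

    def add_adapter(self, adapter_config, adapter_name="default"):
        if not self._hf_peft_config_loaded:
            # First adapter: initialise the bookkeeping.
            self._hf_peft_config_loaded = True
            self.peft_config = {}
        elif adapter_name in self.peft_config:
            # If an earlier unload deleted `peft_config` but left the flag set,
            # this membership test raises AttributeError instead of re-adding.
            raise ValueError(f"Adapter {adapter_name} already exists.")
        self.peft_config[adapter_name] = adapter_config

    def unload_lora(self):
        # Mirrors the change above: drop the config *and* reset the flag so a
        # later add_adapter starts from a clean slate.
        if hasattr(self, "peft_config"):
            del self.peft_config
        if hasattr(self, "_hf_peft_config_loaded"):
            self._hf_peft_config_loaded = None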
65 changes: 41 additions & 24 deletions tests/lora/utils.py
@@ -291,9 +291,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):

return modules_to_save

def check_if_adapters_added_correctly(
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
):
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
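The diff view truncates the renamed helper here. Judging from the visible fragment and the call sites below, which unpack `pipe, denoiser`, the body plausibly continues along these lines; the denoiser handling shown is an assumption for illustration, not text from this PR.

# Hypothetical continuation of add_adapters_to_pipeline, inferred from the
# call sites below; the real helper in tests/lora/utils.py may differ.
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
    if text_lora_config is not None:
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)

    # Pipelines under test carry either a UNet or a transformer denoiser.
    denoiser = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
    if denoiser_lora_config is not None:
        denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)

    return pipe, denoiser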
@@ -345,7 +343,7 @@ def test_simple_inference_with_text_lora(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -428,7 +426,7 @@ def test_low_cpu_mem_usage_with_loading(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -484,7 +482,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -522,7 +520,7 @@ def test_simple_inference_with_text_lora_fused(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

pipe.fuse_lora()
# Fusing should still keep the LoRA layers
@@ -554,7 +552,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -589,7 +587,7 @@ def test_simple_inference_with_text_lora_save_load(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -640,7 +638,7 @@ def test_simple_inference_with_partial_text_lora(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +689,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +732,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -775,7 +773,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -819,7 +817,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

@@ -857,7 +855,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -893,7 +891,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1008,7 @@ def test_wrong_adapter_name_raises_error(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe, _ = self.check_if_adapters_added_correctly(
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)

@@ -1032,7 +1030,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe, _ = self.check_if_adapters_added_correctly(
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)

@@ -1759,7 +1757,7 @@ def test_simple_inference_with_dora(self):
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -1850,7 +1848,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1935,7 @@ def test_set_adapters_match_attention_kwargs(self):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2117,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)

pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2235,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
)
pipe = self.pipeline_class(**components)

pipe, _ = self.check_if_adapters_added_correctly(
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)

@@ -2290,7 +2288,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe, _ = self.check_if_adapters_added_correctly(
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2307,25 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)

def test_lora_unload_add_adapter(self):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]

# unload and then add.
pipe.unload_lora_weights()
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]

def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
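At the user level, the new `test_lora_unload_add_adapter` corresponds roughly to the flow below. The checkpoint id, prompt, and LoraConfig values are placeholders for illustration; the calls being exercised are `add_adapter` and `unload_lora_weights`.

# Rough user-level equivalent of the new test, assuming a UNet-based pipeline.
# The checkpoint id and LoraConfig values are illustrative placeholders.
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])

# Attach a LoRA adapter to the denoiser through the PEFT integration.
pipe.unet.add_adapter(lora_config, adapter_name="default")
image = pipe("a photo of a corgi", num_inference_steps=2).images[0]

# Unload every LoRA layer; with this PR the PEFT bookkeeping flag is reset too.
pipe.unload_lora_weights()

# Adding an adapter again now succeeds instead of tripping over stale state.
pipe.unet.add_adapter(lora_config, adapter_name="default")
image = pipe("a photo of a corgi", num_inference_steps=2).images[0]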