Skip to content

Commit 8a69273

Browse files
FIX [PEFT / Core] Copy the state dict when passing it to load_lora_weights (huggingface#7058)
* copy the state dict in load_lora_weights
* fixup
1 parent 5aa31bd commit 8a69273

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/diffusers/loaders/lora.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def load_lora_weights(
106106
if not USE_PEFT_BACKEND:
107107
raise ValueError("PEFT backend is required for this method.")
108108

109+
# if a dict is passed, copy it instead of modifying it inplace
110+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
111+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
112+
109113
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
110114
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
111115

@@ -1229,6 +1233,10 @@ def load_lora_weights(
12291233
# it here explicitly to be able to tell that it's coming from an SDXL
12301234
# pipeline.
12311235

1236+
# if a dict is passed, copy it instead of modifying it inplace
1237+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
1238+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1239+
12321240
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
12331241
state_dict, network_alphas = self.lora_state_dict(
12341242
pretrained_model_name_or_path_or_dict,

tests/lora/test_lora_layers_peft.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
from huggingface_hub import hf_hub_download
2727
from huggingface_hub.repocard import RepoCard
2828
from packaging import version
29+
from safetensors.torch import load_file
2930
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
3031

3132
from diffusers import (
3233
AutoencoderKL,
3334
AutoPipelineForImage2Image,
35+
AutoPipelineForText2Image,
3436
ControlNetModel,
3537
DDIMScheduler,
3638
DiffusionPipeline,
@@ -1745,6 +1747,40 @@ def test_load_unload_load_kohya_lora(self):
17451747
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
17461748
release_memory(pipe)
17471749

1750+
def test_not_empty_state_dict(self):
1751+
# Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
1752+
pipe = AutoPipelineForText2Image.from_pretrained(
1753+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
1754+
).to("cuda")
1755+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
1756+
1757+
cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
1758+
lcm_lora = load_file(cached_file)
1759+
1760+
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
1761+
self.assertTrue(lcm_lora != {})
1762+
release_memory(pipe)
1763+
1764+
def test_load_unload_load_state_dict(self):
1765+
# Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
1766+
pipe = AutoPipelineForText2Image.from_pretrained(
1767+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
1768+
).to("cuda")
1769+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
1770+
1771+
cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
1772+
lcm_lora = load_file(cached_file)
1773+
previous_state_dict = lcm_lora.copy()
1774+
1775+
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
1776+
self.assertDictEqual(lcm_lora, previous_state_dict)
1777+
1778+
pipe.unload_lora_weights()
1779+
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
1780+
self.assertDictEqual(lcm_lora, previous_state_dict)
1781+
1782+
release_memory(pipe)
1783+
17481784

17491785
@slow
17501786
@require_torch_gpu

0 commit comments

Comments (0)