
Commit 699dfb0

feat: support DoRA LoRA from community (huggingface#7371)
* feat: support dora loras from community
* safe-guard dora operations under peft version.
* pop use_dora when False
* make dora lora from kohya work.
* fix: kohya conversion utils.
* add a fast test for DoRA compatibility.
* add a nightly test.
1 parent 484c8ef commit 699dfb0

File tree

8 files changed: +138 additions, -5 deletions

src/diffusers/loaders/lora.py
src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/utils/__init__.py
src/diffusers/utils/import_utils.py
src/diffusers/utils/peft_utils.py
src/diffusers/utils/state_dict_utils.py
tests/lora/test_lora_layers_sdxl.py
tests/lora/utils.py
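
With this change, a DoRA LoRA (for example one trained with the kohya-ss scripts) loads through the same user-facing API as a regular LoRA, provided `peft` >= 0.9.0 is installed. A minimal sketch, mirroring the nightly test added in this commit and using its checkpoint:

import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")  # kohya-trained DoRA checkpoint
pipeline.enable_model_cpu_offload()

image = pipeline(
    "photo of ohwx dog",
    num_inference_steps=10,
    generator=torch.manual_seed(0),
).images[0]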

src/diffusers/loaders/lora.py

Lines changed: 28 additions & 2 deletions
@@ -36,6 +36,7 @@
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -113,7 +114,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -451,6 +452,15 @@ def load_lora_into_unet(
             rank[key] = val.shape[1]

         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)

         # adapter_name
@@ -572,6 +582,15 @@ def load_lora_into_text_encoder(
             }

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name
@@ -654,6 +673,13 @@ def load_lora_into_transformer(
             rank[key] = val.shape[1]

         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)

         # adapter_name
@@ -1243,7 +1269,7 @@ def load_lora_weights(
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
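
The guards above mean that on `peft` >= 0.9.0 the `use_dora` flag produced by `get_peft_kwargs` flows straight into `peft`'s `LoraConfig`, while on older `peft` a DoRA checkpoint raises and a plain LoRA simply drops the unsupported kwarg. A hedged sketch with toy values, assuming `peft` >= 0.9.0 (not the library's exact call site):

from peft import LoraConfig

# Toy kwargs standing in for what get_peft_kwargs() returns for a DoRA checkpoint;
# on peft < 0.9.0 the loader raises before reaching this point.
lora_config_kwargs = {
    "r": 4,
    "lora_alpha": 4,
    "target_modules": ["to_q", "to_k", "to_v", "to_out.0"],
    "use_dora": True,
}
lora_config = LoraConfig(**lora_config_kwargs)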

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 29 additions & 1 deletion
@@ -14,7 +14,7 @@

 import re

-from ..utils import logging
+from ..utils import is_peft_version, logging


 logger = logging.get_logger(__name__)
@@ -128,6 +128,15 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
+    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+
+    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError(
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )

     # every down weight has a corresponding up weight and potentially an alpha weight
     lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
@@ -198,6 +207,12 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             unet_state_dict[diffusers_name] = state_dict.pop(key)
             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

+            if is_unet_dora_lora:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             if lora_name.startswith(("lora_te_", "lora_te1_")):
                 key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
@@ -229,6 +244,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

+            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
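
For intuition, here is a toy illustration (hypothetical key, not taken from a real checkpoint) of the UNet-branch rename above: the kohya `dora_scale` tensor is re-keyed so that `peft` picks it up as the DoRA magnitude vector.

diffusers_name = "mid_block.attentions.0.to_q.lora.down.weight"  # hypothetical converted key
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
magnitude_key = diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
print(magnitude_key)  # mid_block.attentions.0.to_q.lora_magnitude_vector.weight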

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@
     is_note_seq_available,
     is_onnx_available,
     is_peft_available,
+    is_peft_version,
     is_scipy_available,
     is_tensorboard_available,
     is_torch_available,

src/diffusers/utils/import_utils.py

Lines changed: 14 additions & 0 deletions
@@ -628,6 +628,20 @@ def is_accelerate_version(operation: str, version: str):
     return compare_versions(parse(_accelerate_version), operation, version)


+def is_peft_version(operation: str, version: str):
+    """
+    Args:
+    Compares the current PEFT version to a given reference with an operation.
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _peft_version:
+        return False
+    return compare_versions(parse(_peft_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Args:
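
A quick usage sketch of the new helper; like `is_accelerate_version` above it, it compares against the installed `peft` and returns `False` when `peft` is not installed at all, so the check is safe to call unconditionally:

from diffusers.utils import is_peft_version

print(is_peft_version(">=", "0.9.0"))  # True on an environment with peft >= 0.9.0
print(is_peft_version("<", "0.9.0"))   # True only on older peft; False if peft is missing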

src/diffusers/utils/peft_utils.py

Lines changed: 2 additions & 0 deletions
@@ -171,13 +171,15 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True

     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
+    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)

     lora_config_kwargs = {
         "r": r,
         "lora_alpha": lora_alpha,
         "rank_pattern": rank_pattern,
         "alpha_pattern": alpha_pattern,
         "target_modules": target_modules,
+        "use_dora": use_dora,
     }
     return lora_config_kwargs
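
The detection is purely key-based: DoRA state dicts carry `lora_magnitude_vector` entries next to the usual LoRA weights. A toy illustration with hypothetical keys and placeholder values:

peft_state_dict = {
    "mid_block.attentions.0.to_q.lora_A.weight": "...",  # placeholders standing in for tensors
    "mid_block.attentions.0.to_q.lora_B.weight": "...",
    "mid_block.attentions.0.to_q.lora_magnitude_vector": "...",
}
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
print(use_dora)  # True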

src/diffusers/utils/state_dict_utils.py

Lines changed: 8 additions & 0 deletions
@@ -47,6 +47,7 @@ class StateDictType(enum.Enum):
     ".to_v_lora.up": ".to_v.lora_B",
     ".lora.up": ".lora_B",
     ".lora.down": ".lora_A",
+    ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
 }


@@ -104,6 +105,10 @@ class StateDictType(enum.Enum):
     ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
     ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
     ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
+    ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
+    ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
+    ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
+    ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
 }

 PEFT_TO_KOHYA_SS = {
@@ -315,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
             kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
         elif "unet" in kohya_key:
             kohya_key = kohya_key.replace("unet", "lora_unet")
+        elif "lora_magnitude_vector" in kohya_key:
+            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
+
         kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
         kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
         kohya_ss_state_dict[kohya_key] = weight
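
A small sketch of what the new mapping entries accomplish during conversion (hypothetical key; the real converters walk the full tables above): a diffusers-style attention key carrying a DoRA magnitude vector is renamed to the `*_proj` naming used for the text encoders on the `peft` side.

mapping = {".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector"}

key = "text_model.encoder.layers.0.self_attn.to_q.lora_magnitude_vector"
for pattern, replacement in mapping.items():
    key = key.replace(pattern, replacement)
print(key)  # text_model.encoder.layers.0.self_attn.q_proj.lora_magnitude_vector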

tests/lora/test_lora_layers_sdxl.py

Lines changed: 18 additions & 0 deletions
@@ -630,3 +630,21 @@ def test_integration_logits_multi_adapter(self):
         expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
         max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
         assert max_diff < 1e-3
+
+    @nightly
+    def test_integration_logits_for_dora_lora(self):
+        pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
+        pipeline.enable_model_cpu_offload()
+
+        images = pipeline(
+            "photo of ohwx dog",
+            num_inference_steps=10,
+            generator=torch.manual_seed(0),
+            output_type="np",
+        ).images
+
+        predicted_slice = images[0, -3:, -3:, -1].flatten()
+        expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+        max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
+        assert max_diff < 1e-3

tests/lora/utils.py

Lines changed: 38 additions & 2 deletions
@@ -72,7 +72,7 @@ class PeftLoraLoaderMixinTests:
     unet_kwargs = None
     vae_kwargs = None

-    def get_dummy_components(self, scheduler_cls=None):
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
         rank = 4

@@ -96,10 +96,15 @@ def get_dummy_components(self, scheduler_cls=None):
             lora_alpha=rank,
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
             init_lora_weights=False,
+            use_dora=use_dora,
         )

         unet_lora_config = LoraConfig(
-            r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+            r=rank,
+            lora_alpha=rank,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
         )

         if self.has_two_text_encoders:
@@ -1074,6 +1079,37 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self):
                 "Fused lora should not change the output",
             )

+    @require_peft_version_greater(peft_version="0.9.0")
+    def test_simple_inference_with_dora(self):
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)
+
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertFalse(
+                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
+                "DoRA lora should change the output",
+            )
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         """
"""
