Commit ed50768

[LoRA] don't break offloading for incompatible lora ckpts. (huggingface#5085)
* don't break offloading for incompatible lora ckpts.
* debugging
* better condition.
* fix
* fix
* fix
* fix

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7974fad commit ed50768
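
For context, here is a minimal usage sketch of the behaviour this commit targets (the model id and the LoRA source below are illustrative placeholders, not taken from the commit): with CPU offloading enabled, an incompatible checkpoint now fails validation before any accelerate hooks are removed, so offloading stays intact, while a valid checkpoint is loaded and the hooks are re-applied afterwards inside the loaders.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder model id
)
pipe.enable_model_cpu_offload()  # installs accelerate CpuOffload hooks on the pipeline components

# An incompatible state dict (no "lora" in any key) is now rejected up front,
# before the offload hooks are touched, so the pipeline stays offloaded.
not_a_lora = {"some.weight": torch.zeros(4)}
try:
    pipe.load_lora_weights(not_a_lora)
except ValueError as err:
    print(err)  # "Invalid LoRA checkpoint."

# A valid LoRA source still loads: the loaders remove the hooks, load the LoRA layers,
# then call enable_model_cpu_offload() / enable_sequential_cpu_offload() again.
pipe.load_lora_weights("path/or/repo-id/of/a/lora")  # placeholder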

File tree: 1 file changed (+97, -60 lines)


src/diffusers/loaders.py

Lines changed: 97 additions & 60 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -33,7 +32,6 @@
     _get_model_file,
     deprecate,
     is_accelerate_available,
-    is_accelerate_version,
     is_omegaconf_available,
     is_transformers_available,
     logging,
@@ -308,6 +306,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
         is_network_alphas_none = network_alphas is None
 
         allow_pickle = False
@@ -461,6 +462,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                 else:
                     lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -490,19 +492,44 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                         cross_attention_dim=cross_attention_dim,
                     )
                     attn_processors[key].load_state_dict(value_dict)
-
-            self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )
 
+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        # only custom diffusion needs to set attn processors
+        if is_custom_diffusion:
+            self.set_attn_processor(attn_processors)
+
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)
 
         self.to(dtype=self.dtype, device=self.device)
 
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -1072,41 +1099,31 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
 
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
 
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
 
-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def lora_state_dict(
         cls,
@@ -1403,7 +1420,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
         return new_state_dict
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -1445,13 +1462,22 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)
 
-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )
 
     @classmethod
     def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1561,11 +1587,15 @@ def load_lora_into_text_encoder(
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )
+
             if low_cpu_mem_usage:
                 device = next(iter(text_encoder_lora_state_dict.values())).device
                 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1581,8 +1611,33 @@ def load_lora_into_text_encoder(
                     f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                 )
 
+            # <Unsafe code
+            # We can be sure that the following works as all we do is change the dtype and device of the text encoder
+            # Now we remove any existing hooks to
+            is_model_cpu_offload = False
+            is_sequential_cpu_offload = False
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, torch.nn.Module):
+                        if hasattr(component, "_hf_hook"):
+                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                            is_sequential_cpu_offload = isinstance(
+                                getattr(component, "_hf_hook"), AlignDevicesHook
+                            )
+                            logger.info(
+                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                            )
+                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
 
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2652,31 +2707,17 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
 
-        # Remove any existing hooks.
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-        else:
-            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        for _, component in self.components.items():
-            if isinstance(component, torch.nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
         )
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
 
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2685,6 +2726,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )
 
         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2695,14 +2737,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )
 
-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def save_lora_weights(
         self,
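
The hook handling that the diff repeats in `load_attn_procs` and `load_lora_into_text_encoder` follows one pattern: detect which accelerate offload hook (if any) is attached to each pipeline component, remove it before loading the LoRA layers, and re-enable the matching offload mode afterwards. A condensed sketch of that pattern follows; the helper name `_detach_offload_hooks` is illustrative and not part of diffusers.

import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def _detach_offload_hooks(pipeline):
    # Returns flags describing which offload mode was active so it can be restored later.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipeline.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
            # Sequential offload attaches hooks to submodules, so remove them recursively.
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload


# After the LoRA layers are in place, the loaders restore offloading:
#     if is_model_cpu_offload:
#         pipeline.enable_model_cpu_offload()
#     elif is_sequential_cpu_offload:
#         pipeline.enable_sequential_cpu_offload()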
