 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
     _get_model_file,
     deprecate,
     is_accelerate_available,
-    is_accelerate_version,
     is_omegaconf_available,
     is_transformers_available,
     logging,
@@ -308,6 +306,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
         is_network_alphas_none = network_alphas is None

         allow_pickle = False
@@ -461,6 +462,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                 else:
                     lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -490,19 +492,44 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                         cross_attention_dim=cross_attention_dim,
                     )
                     attn_processors[key].load_state_dict(value_dict)
-
-        self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to the pipeline components.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        # only custom diffusion needs to set attn processors
+        if is_custom_diffusion:
+            self.set_attn_processor(attn_processors)
+
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

         self.to(dtype=self.dtype, device=self.device)

+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
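
Note (editor's sketch): the hook handling added to `load_attn_procs` in the hunk above follows a generic "strip hooks, mutate, restore offload" pattern. The helper below only illustrates that pattern and is not code from this PR; the function name `_toggle_offload_hooks` and the `mutate` callback are invented for the example, while the accelerate APIs (`CpuOffload`, `AlignDevicesHook`, `remove_hook_from_module`) are the same ones used in the diff.

import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def _toggle_offload_hooks(pipeline, mutate):
    # Detect which offloading scheme (if any) is active on the pipeline components.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipeline.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload |= isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload |= isinstance(component._hf_hook, AlignDevicesHook)
            # Sequential offload attaches hooks to submodules, so remove them recursively.
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    mutate()  # e.g. attach LoRA layers or attention processors

    # Re-install whichever offloading scheme was active before.
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
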
@@ -1072,41 +1099,31 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def lora_state_dict(
         cls,
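
With `_pipeline=self` threaded through, the user-facing effect is that LoRA loading now works transparently on an offloaded pipeline. A minimal usage sketch, assuming a standard SD 1.5 checkpoint (the LoRA path is a placeholder, not taken from this PR):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # installs accelerate CpuOffload hooks

# load_lora_weights() now removes the hooks, loads the LoRA layers,
# and re-enables model CPU offload internally.
pipe.load_lora_weights("path/to/lora")  # placeholder path
image = pipe("a prompt", num_inference_steps=20).images[0]
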
@@ -1403,7 +1420,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -1445,13 +1462,22 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)

-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )

     @classmethod
     def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1561,11 +1587,15 @@ def load_lora_into_text_encoder(
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )

-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )
+
             if low_cpu_mem_usage:
                 device = next(iter(text_encoder_lora_state_dict.values())).device
                 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
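
The `is_pipeline_offloaded` check added above reduces to asking whether any pipeline component still carries an accelerate `_hf_hook`. A standalone version of that check, with an illustrative helper name, might look like this:

import torch


def _is_pipeline_offloaded(pipeline) -> bool:
    # True if any torch module in the pipeline has an accelerate hook attached
    # (either CpuOffload or AlignDevicesHook).
    return any(
        isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook")
        for component in pipeline.components.values()
    )

When this returns True, the code above prefers low-cpu-mem-usage (meta-device) loading, as its log message notes.
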
@@ -1581,8 +1611,33 @@ def load_lora_into_text_encoder(
15811611 f"failed to load text encoder state dict, unexpected keys: { load_state_dict_results .unexpected_keys } "
15821612 )
15831613
1614+ # <Unsafe code
1615+ # We can be sure that the following works as all we do is change the dtype and device of the text encoder
+            # Now we remove any existing hooks to the pipeline components.
+            is_model_cpu_offload = False
+            is_sequential_cpu_offload = False
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, torch.nn.Module):
+                        if hasattr(component, "_hf_hook"):
+                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                            is_sequential_cpu_offload = isinstance(
+                                getattr(component, "_hf_hook"), AlignDevicesHook
+                            )
+                            logger.info(
+                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                            )
+                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2652,31 +2707,17 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.

-        # Remove any existing hooks.
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-        else:
-            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        for _, component in self.components.items():
-            if isinstance(component, torch.nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
         )
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2685,6 +2726,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2695,14 +2737,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def save_lora_weights(
         self,
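
For the SDXL mixin the same flow applies, with the state dict split by prefix between `text_encoder` and `text_encoder_2`. A usage sketch under the assumption of the public SDXL base checkpoint (the LoRA file path is a placeholder):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()  # installs AlignDevicesHook recursively

# The loader validates the checkpoint (every key must contain "lora"),
# strips the hooks, loads the UNet and both text encoder LoRA layers,
# then re-enables sequential CPU offload.
pipe.load_lora_weights("path/to/sdxl_lora.safetensors")  # placeholder path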