@@ -25,6 +25,7 @@
from huggingface_hub import hf_hub_download
from torch import nn

+ from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
@@ -56,6 +57,7 @@

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+ TOTAL_EXAMPLE_KEYS = 5

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
    return attn_modules


+ def text_encoder_mlp_modules(text_encoder):
+     mlp_modules = []
+
+     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+         for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+             mlp_mod = layer.mlp
+             name = f"text_model.encoder.layers.{i}.mlp"
+             mlp_modules.append((name, mlp_mod))
+     else:
+         raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+     return mlp_modules
+
+
def text_encoder_lora_state_dict(text_encoder):
    state_dict = {}

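The new `text_encoder_mlp_modules` helper mirrors the existing `text_encoder_attn_modules` and yields `(name, module)` pairs for every MLP block of a CLIP text encoder. A minimal usage sketch, assuming the helper is importable from `diffusers.loaders` and using an illustrative checkpoint:

    from diffusers.loaders import text_encoder_mlp_modules
    from transformers import CLIPTextModel

    # Illustrative checkpoint; any CLIP text encoder works the same way.
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

    for name, mlp in text_encoder_mlp_modules(text_encoder):
        # e.g. "text_model.encoder.layers.0.mlp" with fc1/fc2 sub-layers
        print(name, type(mlp).__name__)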
@@ -304,6 +320,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        # fill attn processors
        attn_processors = {}
+         non_attn_lora_layers = []

        is_lora = all("lora" in k for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -327,13 +344,33 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
-                 rank = value_dict["to_k_lora.down.weight"].shape[0]
-                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)

+                 # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+                 # or add_{k,v,q,out_proj}_proj_lora layers.
+                 if "lora.down.weight" in value_dict:
+                     rank = value_dict["lora.down.weight"].shape[0]
+                     hidden_size = value_dict["lora.up.weight"].shape[0]
+
+                     if isinstance(attn_processor, LoRACompatibleConv):
+                         lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                     elif isinstance(attn_processor, LoRACompatibleLinear):
+                         lora = LoRALinearLayer(
+                             attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                         )
+                     else:
+                         raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                     value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                     lora.load_state_dict(value_dict)
+                     non_attn_lora_layers.append((attn_processor, lora))
+                     continue
+
+                 rank = value_dict["to_k_lora.down.weight"].shape[0]
+                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
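For reference, the rank and width are recovered purely from the LoRA weight shapes: a linear LoRA factorizes the update to a weight W (out x in) as up (out x rank) @ down (rank x in). A small sketch with made-up dimensions:

    import torch

    rank, in_features, out_features = 4, 320, 320  # made-up dimensions
    down = torch.randn(rank, in_features)          # "lora.down.weight"
    up = torch.randn(out_features, rank)           # "lora.up.weight"

    assert down.shape[0] == rank                   # what the loader reads as `rank`
    assert up.shape[0] == out_features             # what the loader reads as `hidden_size`
    delta_w = up @ down                            # low-rank update added on top of the base weight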
@@ -390,10 +427,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+         non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]

        # set layers
        self.set_attn_processor(attn_processors)

+         # set ff layers
+         for target_module, lora_layer in non_attn_lora_layers:
+             if hasattr(target_module, "set_lora_layer"):
+                 target_module.set_lora_layer(lora_layer)
+
    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
@@ -840,7 +883,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
        self.load_lora_into_text_encoder(
-             state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+             state_dict,
+             network_alpha=network_alpha,
+             text_encoder=self.text_encoder,
+             lora_scale=self.lora_scale,
        )

    @classmethod
@@ -1049,6 +1095,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
            text_encoder_lora_state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }
+
            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {prefix}.")

@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
                rank = text_encoder_lora_state_dict[
                    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
                ].shape[1]
+                 patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

-                 cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+                 cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

                # set correct dtype & device
                text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                attn_module.out_proj = attn_module.out_proj.regular_linear_layer

+         for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+             if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                 mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
+                 mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
+
    @classmethod
-     def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+     def _modify_text_encoder(
+         cls,
+         text_encoder,
+         lora_scale=1,
+         network_alpha=None,
+         rank=4,
+         dtype=None,
+         patch_mlp=False,
+     ):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.
        """
@@ -1157,6 +1218,18 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

+         if patch_mlp:
+             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                 mlp_module.fc1 = PatchedLoraProjection(
+                     mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                 )
+                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
+
+                 mlp_module.fc2 = PatchedLoraProjection(
+                     mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                 )
+                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+
        return lora_parameters

    @classmethod
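Conceptually, each patched projection keeps the original `nn.Linear` as `regular_linear_layer` and adds a scaled low-rank branch on top of it. A toy stand-in (not the library's `PatchedLoraProjection` class), assuming that behaviour:

    import torch
    from torch import nn

    class ToyPatchedProjection(nn.Module):
        """Toy stand-in: base linear output plus a scaled low-rank update."""

        def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
            super().__init__()
            self.regular_linear_layer = base
            self.scale = scale
            self.down = nn.Linear(base.in_features, rank, bias=False)
            self.up = nn.Linear(rank, base.out_features, bias=False)
            nn.init.zeros_(self.up.weight)  # start as a no-op, like a freshly initialised LoRA

        def forward(self, x):
            return self.regular_linear_layer(x) + self.scale * self.up(self.down(x))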
@@ -1261,9 +1334,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None
+         unloaded_keys = []

        for key, value in state_dict.items():
-             if "lora_down" in key:
+             if "hada" in key or "skip" in key:
+                 unloaded_keys.append(key)
+             elif "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_name_up = lora_name + ".lora_up.weight"
                lora_name_alpha = lora_name + ".alpha"
@@ -1284,12 +1360,21 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                     diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                     diffusers_name = diffusers_name.replace("proj.out", "proj_out")
                    if "transformer_blocks" in diffusers_name:
                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                            unet_state_dict[diffusers_name] = value
                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                         elif "ff" in diffusers_name:
+                             unet_state_dict[diffusers_name] = value
+                             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                     elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                         unet_state_dict[diffusers_name] = value
+                         unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
                elif lora_name.startswith("lora_te_"):
                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                    diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                    if "self_attn" in diffusers_name:
                        te_state_dict[diffusers_name] = value
                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                     elif "mlp" in diffusers_name:
+                         # Be aware that this is the new diffusers convention and the rest of the code might
+                         # not utilize it yet.
+                         diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                         te_state_dict[diffusers_name] = value
+                         te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+         logger.info("Kohya-style checkpoint detected.")
+         if len(unloaded_keys) > 0:
+             example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
+             logger.warning(
+                 f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
+             )

        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
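To make the key renaming concrete, here is an illustrative trace of the text-encoder MLP branch above, using a hypothetical Kohya-style key; the final `text_encoder.` prefix comes from the dict comprehension just shown (TEXT_ENCODER_NAME is "text_encoder"):

    kohya_key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"

    name = kohya_key.replace("lora_te_", "").replace("_", ".")
    name = name.replace("text.model", "text_model")       # -> "text_model.encoder.layers.0.mlp.fc1.lora.down.weight"
    name = name.replace(".lora.", ".lora_linear_layer.")  # the mlp branch above
    print(f"text_encoder.{name}")
    # text_encoder.text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight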
@@ -1346,6 +1444,10 @@ def unload_lora_weights(self):
        [attention_proc_class] = unet_attention_classes
        self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

+         for _, module in self.unet.named_modules():
+             if hasattr(module, "set_lora_layer"):
+                 module.set_lora_layer(None)
+
        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()

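Continuing the earlier pipeline sketch, an illustrative load/unload round trip:

    pipe.load_lora_weights("path/to/lora")  # patches attention processors, ff/proj layers and the text encoder
    image = pipe("a prompt").images[0]

    pipe.unload_lora_weights()              # restores regular processors and clears LoRA layers via set_lora_layer(None)
    image_without_lora = pipe("a prompt").images[0]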