@@ -1245,6 +1245,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
 
             if len(text_encoder_lora_state_dict) > 0:
                 logger.info(f"Loading {prefix}.")
+                rank = {}
 
                 if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
                     # Convert from the old naming convention to the new naming convention.
@@ -1283,10 +1284,17 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
                             f"{name}.out_proj.lora_linear_layer.down.weight"
                         ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
 
-                rank = text_encoder_lora_state_dict[
-                    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
-                ].shape[1]
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+                    rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
+
                 patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+                if patch_mlp:
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+                        rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+                        rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
+                        rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
 
                 if network_alphas is not None:
                     alpha_keys = [
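
As a quick illustration of the rank-collection change in the hunk above, here is a minimal, self-contained sketch using dummy tensors; the hard-coded layer names stand in for what `text_encoder_attn_modules(text_encoder)` yields in the real loader:

```python
import torch

# LoRA "up" weights have shape (out_features, rank), so shape[1] is the per-module rank.
# The state dict below is dummy data; in the real code it comes from the loaded LoRA checkpoint.
text_encoder_lora_state_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight": torch.zeros(768, 4),
    "text_model.encoder.layers.1.self_attn.out_proj.lora_linear_layer.up.weight": torch.zeros(768, 8),
}

rank = {}
for name in (
    "text_model.encoder.layers.0.self_attn",
    "text_model.encoder.layers.1.self_attn",
):
    rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

print(rank)  # layer 0 -> 4, layer 1 -> 8: per-module ranks are kept instead of one global value
```
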
@@ -1344,7 +1352,7 @@ def _modify_text_encoder(
         text_encoder,
         lora_scale=1,
         network_alphas=None,
-        rank=4,
+        rank: Union[Dict[str, int], int] = 4,
         dtype=None,
         patch_mlp=False,
     ):
@@ -1365,38 +1373,45 @@ def _modify_text_encoder(
             value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
             out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
 
+            if isinstance(rank, dict):
+                current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
+            else:
+                current_rank = rank
+
             attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
+                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
 
             attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
+                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
 
             attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
+                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
 
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
                 fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
                 fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
+                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
+                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
 
                 mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
 
                 mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
 
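
On the consumption side, `_modify_text_encoder` now accepts either a plain int (old behaviour: one rank everywhere) or the per-module dict built during loading. A rough sketch of that branch, using a hypothetical `resolve_rank` helper rather than the actual library code:

```python
from typing import Dict, Union


def resolve_rank(rank: Union[Dict[str, int], int], name: str) -> int:
    """Hypothetical helper mirroring the isinstance(rank, dict) branch added in the diff."""
    if isinstance(rank, dict):
        # Keys follow the same pattern the loader used when collecting the ranks.
        return rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
    return rank


# Old-style scalar rank still works unchanged:
print(resolve_rank(4, "text_model.encoder.layers.0.self_attn"))  # 4

# New-style per-module mapping:
ranks = {"text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight": 8}
print(resolve_rank(ranks, "text_model.encoder.layers.0.self_attn"))  # 8
```
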