
Commit 1e0395e

[LoRA] ensure different LoRA ranks for text encoders can be properly handled (huggingface#4669)
* debugging starts
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging ends, but does it?
* more robustness.
1 parent 9141c1f commit 1e0395e
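Before this change, load_lora_into_text_encoder inferred a single rank from one hard-coded layer key, so text-encoder LoRAs whose modules were trained with different ranks could not be loaded correctly; the diff below collects a rank per module instead. As a minimal usage sketch of the code path this affects (the model id and LoRA file name below are placeholders, not taken from this commit):

import torch
from diffusers import StableDiffusionPipeline

# Placeholder model id and LoRA weight file; a text-encoder LoRA whose layers use
# differing ranks should now load through this call without shape mismatches.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors")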

File tree

1 file changed: +25 -10 lines changed


src/diffusers/loaders.py

Lines changed: 25 additions & 10 deletions
@@ -1245,6 +1245,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p

             if len(text_encoder_lora_state_dict) > 0:
                 logger.info(f"Loading {prefix}.")
+                rank = {}

                 if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
                     # Convert from the old naming convention to the new naming convention.
@@ -1283,10 +1284,17 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
                             f"{name}.out_proj.lora_linear_layer.down.weight"
                         ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")

-                rank = text_encoder_lora_state_dict[
-                    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
-                ].shape[1]
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+                    rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
+
                 patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+                if patch_mlp:
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+                        rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+                        rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
+                        rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})

                 if network_alphas is not None:
                     alpha_keys = [
@@ -1344,7 +1352,7 @@ def _modify_text_encoder(
         text_encoder,
         lora_scale=1,
         network_alphas=None,
-        rank=4,
+        rank: Union[Dict[str, int], int] = 4,
         dtype=None,
         patch_mlp=False,
     ):
@@ -1365,38 +1373,45 @@ def _modify_text_encoder(
             value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
             out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)

+            if isinstance(rank, dict):
+                current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
+            else:
+                current_rank = rank
+
             attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
+                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())

             attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
+                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())

             attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
+                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
            )
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
                 fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
                 fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
+                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
+                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

                 mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

                 mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
