[LoRA] ensure different LoRA ranks for text encoders can be properly handled #4669
Conversation
The documentation is not available anymore as the PR was closed or merged.
@@ -1344,7 +1352,7 @@ def _modify_text_encoder(
     text_encoder,
     lora_scale=1,
     network_alphas=None,
-    rank=4,
+    rank: Union[Dict[str, int], int] = 4,
To maintain backward compatibility in our training scripts.
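As a quick illustration of why the widened type stays backward compatible, here is a minimal, hypothetical sketch (resolve_rank is not a diffusers function and the key names are made up) of how an argument typed Union[Dict[str, int], int] can serve both old int-only callers and new per-module callers:

```python
from typing import Dict, Union


def resolve_rank(rank: Union[Dict[str, int], int], key: str, default: int = 4) -> int:
    """Return the LoRA rank to use for a given module key.

    An int means "one shared rank for every module" (the old behaviour);
    a dict maps module-specific keys to their own ranks.
    """
    if isinstance(rank, dict):
        return rank.get(key, default)
    return rank


# Old-style call sites that pass a plain int keep working unchanged:
assert resolve_rank(4, "text_model.encoder.layers.0.self_attn.out_proj") == 4

# New-style call sites can assign a different rank per module:
ranks = {"text_model.encoder.layers.0.self_attn.out_proj": 8}
assert resolve_rank(ranks, "text_model.encoder.layers.0.self_attn.out_proj") == 8
```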
+            current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
+            current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
+
             mlp_module.fc1 = PatchedLoraProjection(
-                mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
+                mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
             )
             lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

             mlp_module.fc2 = PatchedLoraProjection(
-                mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
+                mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
We never allow patching the MLP from our training scripts, so this should be okay.
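To make the mechanics above concrete, here is a standalone sketch with a toy LoRA layer (ToyLoraLinearLayer stands in for diffusers' PatchedLoraProjection / LoRALinearLayer; the key names and feature sizes are illustrative): each MLP projection pops its own rank out of the dict instead of sharing a single int.

```python
import torch.nn as nn


class ToyLoraLinearLayer(nn.Module):
    """Toy LoRA layer: a low-rank down/up projection pair of a given rank."""

    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


# Hypothetical per-module rank dict, keyed the same way as in the diff above.
rank = {
    "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight": 8,
    "text_model.encoder.layers.0.mlp.fc2.lora_linear_layer.up.weight": 16,
}

name = "text_model.encoder.layers.0.mlp"
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

# fc1 and fc2 of a CLIP-style MLP have different shapes, and with this change
# they can also carry different LoRA ranks.
lora_fc1 = ToyLoraLinearLayer(768, 3072, rank=current_rank_fc1)
lora_fc2 = ToyLoraLinearLayer(3072, 768, rank=current_rank_fc2)
assert lora_fc1.up.in_features == 8 and lora_fc2.up.in_features == 16
```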
-            rank = text_encoder_lora_state_dict[
-                "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
-            ].shape[1]
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+                rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})

             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+            if patch_mlp:
+                for name, _ in text_encoder_mlp_modules(text_encoder):
+                    rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+                    rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+                    rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
+                    rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
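The ranks themselves are read off the second dimension of each LoRA up-weight. Below is a self-contained sketch with dummy tensors (the real loader walks text_encoder_attn_modules / text_encoder_mlp_modules rather than scanning key suffixes, so treat this only as an illustration of where the numbers come from):

```python
import torch

# Dummy LoRA state dict: up-weights have shape (out_features, rank).
text_encoder_lora_state_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight": torch.zeros(768, 8),
    "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight": torch.zeros(3072, 16),
}

rank = {}
for key, weight in text_encoder_lora_state_dict.items():
    if key.endswith(".lora_linear_layer.up.weight"):
        # The rank is the second dimension of the up projection weight.
        rank[key] = weight.shape[1]

patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict)

print(rank)       # out_proj layer -> 8, fc1 layer -> 16
print(patch_mlp)  # True
```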
Would it instead be possible to use _register_state_dict_pre_hook on LoRALinearLayer so they can look at the incoming weights when the state dict is loaded and change the internal weights to the appropriate shape? This allows us to treat the state dict more transparently and avoid having to construct a rank dict by looking at strings in the passed-in state dict.
> This allows us to treat the state dict more transparently and avoid having to construct a rank dict by looking at strings in the passed-in state dict.

_register_state_dict_pre_hook will also need to look at the state dicts if we were to retrieve the ranks, no? What you're suggesting isn't clear to me, so I need some elaboration.
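For readers following along, here is a rough sketch of what the hook-based alternative could look like on a toy LoRA layer, using PyTorch's private _register_load_state_dict_pre_hook (the load-time counterpart of the hook named above). This is only an illustration of the idea under discussion, not diffusers code, and it is not the approach that was merged.

```python
import torch.nn as nn


class ToyLoRALinearLayer(nn.Module):
    """Toy LoRA layer that resizes itself to the rank carried by an incoming state dict."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # Runs right before this module's tensors are copied in load_state_dict().
        self._register_load_state_dict_pre_hook(self._resize_to_incoming_rank)

    def _resize_to_incoming_rank(self, state_dict, prefix, *args):
        up_key = prefix + "up.weight"
        if up_key in state_dict:
            incoming_rank = state_dict[up_key].shape[1]
            if incoming_rank != self.up.in_features:
                # Recreate the projections with the rank found in the checkpoint.
                self.down = nn.Linear(self.down.in_features, incoming_rank, bias=False)
                self.up = nn.Linear(incoming_rank, self.up.out_features, bias=False)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


layer = ToyLoRALinearLayer(768, 768, rank=4)
donor = ToyLoRALinearLayer(768, 768, rank=16)
layer.load_state_dict(donor.state_dict())  # rank is adjusted before weights are copied
assert layer.up.in_features == 16
```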
Merging after internal discussions with Will on Slack.
[LoRA] ensure different LoRA ranks for text encoders can be properly handled (huggingface#4669)

* debugging starts
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging ends, but does it?
* more robustness.
Internal thread: https://huggingface.slack.com/archives/C03UQJENJTV/p1692343210344049