[LoRA] ensure different LoRA ranks for text encoders can be properly handled #4669


Merged
sayakpaul merged 8 commits into main from fix/text-encoder-lora-sdxl-2 on Aug 22, 2023

Conversation

sayakpaul (Member)

HuggingFaceDocBuilderDev commented Aug 18, 2023

The documentation is not available anymore as the PR was closed or merged.

```diff
@@ -1344,7 +1352,7 @@ def _modify_text_encoder(
     text_encoder,
     lora_scale=1,
     network_alphas=None,
-    rank=4,
+    rank: Union[Dict[str, int], int] = 4,
```
sayakpaul (Member Author)

To have backward compatibility in our training scripts.
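
To illustrate what that backward compatibility looks like in practice, here is a minimal sketch of accepting either a single int or a per-layer dict; `resolve_rank` is a hypothetical helper for illustration, not a function in diffusers:

```python
from typing import Dict, Union


def resolve_rank(rank: Union[Dict[str, int], int], rank_key: str, default: int = 4) -> int:
    """Return the LoRA rank to use for one patched layer.

    Older training scripts pass a single int that applies everywhere;
    newer loading code can pass a dict keyed by the layer's up-weight name.
    """
    if isinstance(rank, dict):
        return rank.get(rank_key, default)  # per-layer rank
    return rank  # backward-compatible path: one rank shared by every layer


# A single int behaves as before, a dict resolves per layer:
key = "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
assert resolve_rank(4, key) == 4
assert resolve_rank({key: 8}, key) == 8
```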

Comment on lines +1405 to +1414
```diff
+                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
+                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

                 mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

                 mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
```
sayakpaul (Member Author)

We never allow patching the MLP from our training scripts. So, this should be okay.
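
For background on why the rank dict is keyed by `...lora_linear_layer.up.weight` names: in a LoRA linear layer the up projection's weight has shape `(out_features, rank)`, so the key both identifies the layer and lets the rank be read back from the tensor itself. A stripped-down illustration (not the actual diffusers `LoRALinearLayer`):

```python
import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    """Illustrative LoRA adapter: project down to `rank`, then back up."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # weight: (rank, in_features)
        self.up = nn.Linear(rank, out_features, bias=False)   # weight: (out_features, rank)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))


layer = TinyLoRALinear(in_features=768, out_features=768, rank=8)
# The second dimension of the up weight is exactly the rank,
# which is what the loading code reads back via `.shape[1]`.
assert layer.up.weight.shape[1] == 8
```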

Comment on lines -1286 to +1297
```diff
-            rank = text_encoder_lora_state_dict[
-                "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
-            ].shape[1]
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+                rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})

             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+            if patch_mlp:
+                for name, _ in text_encoder_mlp_modules(text_encoder):
+                    rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+                    rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+                    rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
+                    rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
```
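
As a toy version of the loop above (tensor values are random, only two attention layers are shown, and the state dict keys are iterated directly instead of going through `text_encoder_attn_modules`), adapters trained with different ranks produce different entries in the rank dict:

```python
import torch

# Pretend LoRA state dict whose two up weights were trained at different ranks.
text_encoder_lora_state_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight": torch.randn(768, 4),
    "text_model.encoder.layers.1.self_attn.out_proj.lora_linear_layer.up.weight": torch.randn(768, 8),
}

rank = {}
for key, weight in text_encoder_lora_state_dict.items():
    if key.endswith("lora_linear_layer.up.weight"):
        rank[key] = weight.shape[1]  # up.weight: (out_features, rank)

# rank now maps layer 0 to 4 and layer 1 to 8, so each PatchedLoraProjection
# can be built with the rank its checkpoint was actually trained with.
```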
Contributor

Would it instead be possible to use _register_state_dict_pre_hook on LoRALinearLayer so it can look at the incoming weights when the state dict is loaded and change the internal weights to the appropriate shape? This allows us to treat the state dict more transparently and avoid having to construct a rank dict by looking at strings in the passed-in state dict.

sayakpaul (Member Author)

> This allows us to treat the state dict more transparently and avoid having to construct a rank dict by looking at strings in the passed-in state dict.

_register_state_dict_pre_hook would also need to look at the state dict if we were to retrieve the ranks from it, no?

What you're suggesting isn't clear to me, so I need some elaboration.
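
One possible reading of the suggestion, sketched here purely for illustration: a load-time pre-hook that inspects the incoming weights and resizes the LoRA layer to their rank before they are copied in. This assumes PyTorch's private `nn.Module._register_load_state_dict_pre_hook` (which may or may not be the hook the comment above has in mind), and `SelfResizingLoRALinear` is not diffusers code:

```python
import torch
import torch.nn as nn


class SelfResizingLoRALinear(nn.Module):
    """Illustrative LoRA layer that adapts its rank to the checkpoint being loaded."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # Runs inside load_state_dict, before the incoming tensors are copied in.
        self._register_load_state_dict_pre_hook(self._resize_to_incoming_rank)

    def _resize_to_incoming_rank(self, state_dict, prefix, *args):
        up_key = prefix + "up.weight"
        if up_key in state_dict:
            incoming_rank = state_dict[up_key].shape[1]
            if incoming_rank != self.up.weight.shape[1]:
                # Re-create the projections with the rank of the incoming checkpoint.
                self.down = nn.Linear(self.in_features, incoming_rank, bias=False)
                self.up = nn.Linear(incoming_rank, self.out_features, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))


layer = SelfResizingLoRALinear(768, 768, rank=4)
checkpoint = {"down.weight": torch.randn(8, 768), "up.weight": torch.randn(768, 8)}
layer.load_state_dict(checkpoint)  # hook resizes from rank 4 to rank 8, then weights load
assert layer.up.weight.shape == (768, 8)
```

Whether a hook like this would be cleaner than building the rank dict up front is left open in the thread and deferred to the Slack discussion mentioned below.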

sayakpaul (Member Author)

Merging after internal discussions with Will on Slack.

sayakpaul merged commit 1e0395e into main on Aug 22, 2023
sayakpaul deleted the fix/text-encoder-lora-sdxl-2 branch on August 22, 2023 02:51
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
[LoRA] ensure different LoRA ranks for text encoders can be properly handled (huggingface#4669)

* debugging starts
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging ends, but does it?
* more robustness.

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
[LoRA] ensure different LoRA ranks for text encoders can be properly handled (huggingface#4669)

* debugging starts
* debugging
* debugging
* debugging
* debugging
* debugging
* debugging ends, but does it?
* more robustness.