[Textual inversion] Refactor textual inversion to make it cleaner #5076
Conversation
@@ -2598,7 +2649,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
No need to introduce a new concept if only used in one function call with recursive=...
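To make the suggestion concrete, here is a minimal, hypothetical sketch (not the PR's actual code): the flag can be passed inline to accelerate's `remove_hook_from_module` instead of first being assigned to a separate `recursive` variable. The toy module and the hard-coded flag value below are assumptions for illustration only.

```python
import torch.nn as nn
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module

# Hypothetical toy module, just to make the call concrete.
module = nn.Linear(4, 4)
add_hook_to_module(module, ModelHook())

is_sequential_cpu_offload = True  # in the real code this comes from inspecting the pipeline

# Pass the flag inline rather than introducing a separate `recursive` variable.
remove_hook_from_module(module, recurse=is_sequential_cpu_offload)
```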
For reviewers: it might be helpful not just to look at the diff, but also to check how the new function looks as a stand-alone file.
# 3. Check inputs
self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)

# 4. Load state dicts of textual embeddings
state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)

# 4. Retrieve tokens and embeddings
tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)

# 5. Extend tokens and embeddings for multi vector
tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)

# 6. Make sure all embeddings have the correct size
expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
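The diff cuts off at step 6, so for context here is a hedged sketch of how that size check might continue. It reuses the `tokens`, `embeddings`, and `text_encoder` names from the snippet above; the loop and error message are assumptions, not the PR's exact implementation.

```python
# Hypothetical continuation of step 6: every loaded embedding must match the
# text encoder's hidden size before it can be written into the embedding table.
expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
for token, embedding in zip(tokens, embeddings):
    if embedding.shape[-1] != expected_emb_dim:
        raise ValueError(
            f"Embedding for token {token} has dimension {embedding.shape[-1]}, "
            f"but the text encoder expects {expected_emb_dim}."
        )
```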
Very very nice.
if len(set(valid_tokens)) < len(valid_tokens):
    raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

@staticmethod
Why static?
Just to mark the function as a bit more lightweight and to allow it to be used directly on the class, without needing an instance.
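As a quick illustration of that point (a generic sketch, not the actual loader code): a `@staticmethod` carries no `self`, so it reads as a pure helper and can be called either on the class or on an instance. The class and method names below are hypothetical.

```python
class TextualInversionLoaderSketch:
    """Hypothetical stand-in for the loader mixin, just to show the call styles."""

    @staticmethod
    def _check_tokens(tokens: list) -> None:
        # A lightweight helper: no access to instance state, only its arguments.
        if len(set(tokens)) < len(tokens):
            raise ValueError(f"Duplicate tokens passed: {tokens}")

# Both call styles work, since no instance state is involved:
TextualInversionLoaderSketch._check_tokens(["<cat-toy>", "<style>"])
TextualInversionLoaderSketch()._check_tokens(["<cat-toy>", "<style>"])
```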
Works for me. Thanks for the cleanup!
Happy to do this for the LoRA loader as well, but maybe we want to wait until peft?
…ggingface#5076) * [Textual inversion] Clean loading * [Textual inversion] Clean loading * [Textual inversion] Clean up * [Textual inversion] Clean up * [Textual inversion] Clean up * [Textual inversion] Clean up
What does this PR do?
This PR includes a badly needed refactor of the load_textual_inversion function. Different concerns are factored out into smaller helper functions, making the overall method much easier to read and understand.
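For readers unfamiliar with the entry point being refactored, here is a minimal usage sketch of `load_textual_inversion`; the checkpoint, concept repo, and prompt below are illustrative examples, not part of this PR.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# The refactored entry point: loads learned token embeddings into the
# tokenizer and text encoder so the new token can be used in prompts.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

image = pipe("A <cat-toy> sitting on a bench").images[0]
image.save("cat_toy.png")
```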
This PR also makes sure that we remove the model hooks only once we can be sure that loading will succeed, i.e. right before the textual inversion embeddings are set. This solves the following issue: #5060
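A hedged sketch of that ordering, assuming accelerate-style offload hooks; the function name `load_embeddings_safely` and the surrounding structure are illustrative, not the exact code in the PR.

```python
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

def load_embeddings_safely(pipeline, tokens, embeddings):
    """Illustrative flow: validate first, strip offload hooks only right before
    mutating the text encoder, then restore offloading afterwards."""
    # 1. All validation happens while the hooks are still in place, so a bad
    #    input cannot leave the pipeline in a half-offloaded state (see #5060).
    if len(set(tokens)) < len(tokens):
        raise ValueError(f"Duplicate tokens: {tokens}")

    # 2. Only now, right before writing the embeddings, remove the hooks and
    #    remember which offloading scheme was active.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipeline.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    # 3. Set the textual inversion embeddings (omitted in this sketch).
    ...

    # 4. Re-enable the offloading scheme that was active before.
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
```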