Skip to content

[Textual inversion] Refactor textual inversion to make it cleaner #5076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 18, 2023

Conversation

patrickvonplaten
Copy link
Contributor

@patrickvonplaten patrickvonplaten commented Sep 17, 2023

What does this PR do?

This PR includes a badly needed refactor of the load_textual_inversion function. The function is now much easier to read & understand. Different concepts are factored out into smaller functions making the overall method much easier to understand.

This PR also makes sure that we remove the model hooks only if we can be sure that the loading will work meaning right before we set the textual inversion embedding. This solves the following issue: #5060

@@ -2598,7 +2649,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to introduce a new concept if only used in one function call with recursive=...

@patrickvonplaten patrickvonplaten changed the title Clean text inv loading [Textual inversion] Refactor textual inversion to make it cleaner Sep 17, 2023
@patrickvonplaten
Copy link
Contributor Author

For reviewers: It might be helpful to not just look at the diff, but also check how the new function looks like as a stand-alone file

Comment on lines +971 to +984
# 3. Check inputs
self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)

# 4. Load state dicts of textual embeddings
state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)

# 4. Retrieve tokens and embeddings
tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)

# 5. Extend tokens and embeddings for multi vector
tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)

# 6. Make sure all embeddings have the correct size
expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very very nice.

if len(set(valid_tokens)) < len(valid_tokens):
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

@staticmethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why static?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to mark the function as a bit more light weight and allow it to be used just with the class instance

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me. Thanks for the clean up!

Happy to do this for the LoRA loader, but maybe we want to wait till peft?

@patrickvonplaten patrickvonplaten merged commit 7b39f43 into main Sep 18, 2023
patrickvonplaten added a commit that referenced this pull request Sep 18, 2023
)

* [Textual inversion] Clean loading

* [Textual inversion] Clean loading

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up
@kashif kashif deleted the clean_text_inv_loading branch September 29, 2023 11:39
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…ggingface#5076)

* [Textual inversion] Clean loading

* [Textual inversion] Clean loading

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…ggingface#5076)

* [Textual inversion] Clean loading

* [Textual inversion] Clean loading

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants