This guide will show you how to train a [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model with Textual Inversion. All the training scripts for Textual Inversion used in this guide can be found [here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) if you're interested in taking a closer look at how things work under the hood.
There is a community-created collection of trained Textual Inversion models in the [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library) which are readily available for inference. Over time, this'll hopefully grow into a useful resource as more concepts are added!
Finally, try [installing xFormers](https://huggingface.co/docs/diffusers/main/en/training/optimization/xformers) to reduce your memory footprint with memory-efficient attention. Once you have xFormers installed, add the `--enable_xformers_memory_efficient_attention` argument to the training script. xFormers is not supported for Flax.
If you want to store your model on the Hub, add the following argument to the training script:
```bash
--push_to_hub
```
## Save and load checkpoints
It is often a good idea to regularly save checkpoints of your model during training. This way, you can resume training from a saved checkpoint if your training is interrupted for any reason. To save a checkpoint, pass the following argument to the training script to save the full training state in a subfolder in `output_dir` every 500 steps:
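```bash
--checkpointing_steps=500
```

To resume training from a saved checkpoint, pass the `--resume_from_checkpoint` argument to the training script along with the checkpoint you want to resume from, for example `--resume_from_checkpoint="checkpoint-1500"`.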
## Launch the script

For your training dataset, download these [images of a cat statue](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and store them in a directory.
Set the `MODEL_NAME` environment variable to the model repository id, and the `DATA_DIR` environment variable to the path of the directory containing the images. Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py):
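```bash
# Example values (a sketch based on the cat toy example); tune the
# hyperparameters and paths for your own dataset.
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="./cat"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 \
  --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="textual_inversion_cat"
```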
💡 A full training run takes ~1 hour on one V100 GPU. While you're waiting for the training to complete, feel free to check out [how Textual Inversion works](#how-it-works) in the section below if you're curious!
If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️
Before you begin, make sure you install the Flax-specific dependencies:

```bash
pip install -U -r requirements_flax.txt
```

Then you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py):
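```bash
# Example values (a sketch); MODEL_NAME must point to a Flax-compatible
# checkpoint, and the other hyperparameters mirror the PyTorch example above.
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="./cat"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 \
  --scale_lr \
  --output_dir="textual_inversion_cat"
```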
If you're interested in following along with your model training progress, you can save the generated images from the training process. Add the following arguments to the training script to enable intermediate logging:
- `validation_prompt`, the prompt used to generate samples (this is set to `None` by default and intermediate logging is disabled)
- `num_validation_images`, the number of sample images to generate
- `validation_steps`, the number of steps before generating `num_validation_images` from the `validation_prompt`

```bash
--validation_prompt="A <cat-toy> backpack"
--num_validation_images=4
--validation_steps=100
```
## Inference
Once you have trained a model, you can use it for inference with the [`StableDiffusionPipeline`]. Make sure you include the placeholder token in your prompt, in this case it's `<cat-toy>`.
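As a minimal sketch, assuming the trained weights were saved to the `textual_inversion_cat` output directory from the launch command above (a Hub repository id works the same way):

```py
import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned model; the learned <cat-toy> embedding is part of it.
pipeline = StableDiffusionPipeline.from_pretrained(
    "textual_inversion_cat", torch_dtype=torch.float16
).to("cuda")

# The placeholder token must appear in the prompt for the concept to be used.
image = pipeline("A <cat-toy> backpack", num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```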

<small>Architecture overview from the Textual Inversion <a href="https://textual-inversion.github.io/">blog post.</a></small>
Usually, text prompts are tokenized into an embedding before being passed to a model, which is often a transformer. Textual Inversion does something similar, but it learns a new token embedding, `v*`, from a special token `S*` in the diagram above. The model output is used to condition the diffusion model, which helps the diffusion model understand the prompt and new concepts from just a few example images.
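To make this concrete, here is a rough sketch (not the actual training script) of where that new embedding lives, assuming the placeholder token is added to a CLIP text encoder:

```py
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

# Register the placeholder token and give it its own row in the embedding table.
tokenizer.add_tokens("<cat-toy>")
text_encoder.resize_token_embeddings(len(tokenizer))

# Prompts are tokenized into ids, and each id looks up a row of the table;
# Textual Inversion optimizes only the row belonging to the placeholder token.
ids = tokenizer("A <cat-toy> backpack", return_tensors="pt").input_ids
embeds = text_encoder.get_input_embeddings()(ids)  # shape: (1, seq_len, hidden_dim)
```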
To do this, Textual Inversion uses a generator model and noisy versions of the training images. The generator tries to predict less noisy versions of the images, and the token embedding `v*` is optimized based on how well the generator does. If the token embedding successfully captures the new concept, it gives more useful information to the diffusion model and helps create clearer images with less noise. This optimization process typically requires several thousand steps of exposure to a variety of prompt and image variants.
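In the notation of the Textual Inversion paper, the learned embedding minimizes the usual latent diffusion denoising loss, with everything except the embedding kept frozen (a sketch of the objective, following the paper's notation):

$$
v^* = \arg\min_{v} \; \mathbb{E}_{z \sim \mathcal{E}(x),\, y,\, \epsilon \sim \mathcal{N}(0, 1),\, t} \Big[ \lVert \epsilon - \epsilon_\theta(z_t, t, c_\theta(y)) \rVert_2^2 \Big]
$$

where \\( \epsilon_\theta \\) is the denoising network, \\( z_t \\) is the noised latent at timestep \\( t \\), and \\( c_\theta(y) \\) is the text encoding of a prompt \\( y \\) containing the placeholder token whose embedding is \\( v \\).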