Commit b285d94

[docs] Move Textual Inversion training examples to docs (huggingface#2576)

* 📝 add textual inversion examples to docs
* 🖍 apply feedback
* 🖍 add colab link

1 parent 55660cf commit b285d94

File tree

1 file changed (+131, -38 lines)


docs/source/en/training/text_inversion.mdx

Lines changed: 131 additions & 38 deletions
@@ -14,74 +14,85 @@ specific language governing permissions and limitations under the License.

# Textual Inversion

-Textual Inversion is a technique for capturing novel concepts from a small number of example images in a way that can later be used to control text-to-image pipelines. It does so by learning new 'words' in the embedding space of the pipeline's text encoder. These special words can then be used within text prompts to achieve very fine-grained control of the resulting images.
+[[open-in-colab]]

-![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG)
-_By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation ([image source](https://github.com/rinongal/textual_inversion))._
+[Textual Inversion](https://arxiv.org/abs/2208.01618) is a technique for capturing novel concepts from a small number of example images. While the technique was originally demonstrated with a [latent diffusion model](https://github.com/CompVis/latent-diffusion), it has since been applied to other model variants like [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion). The learned concepts can be used to better control the images generated from text-to-image pipelines. It learns new "words" in the text encoder's embedding space, which are used within text prompts for personalized image generation.

-This technique was introduced in [An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion](https://arxiv.org/abs/2208.01618). The paper demonstrated the concept using a [latent diffusion model](https://github.com/CompVis/latent-diffusion) but the idea has since been applied to other variants such as [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion).
+![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG)
+<small>By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation <a href="https://github.com/rinongal/textual_inversion">(image source)</a></small>

+This guide will show you how to train a [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model with Textual Inversion. All the training scripts for Textual Inversion used in this guide can be found [here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) if you're interested in taking a closer look at how things work under the hood.

-## How It Works
+<Tip>

-![Diagram from the paper showing overview](https://textual-inversion.github.io/static/images/training/training.JPG)
-_Architecture Overview from the [textual inversion blog post](https://textual-inversion.github.io/)_
+There is a community-created collection of trained Textual Inversion models in the [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library) which are readily available for inference. Over time, this'll hopefully grow into a useful resource as more concepts are added!

-Before a text prompt can be used in a diffusion model, it must first be processed into a numerical representation. This typically involves tokenizing the text, converting each token to an embedding and then feeding those embeddings through a model (typically a transformer) whose output will be used as the conditioning for the diffusion model.
+</Tip>

-Textual inversion learns a new token embedding (v* in the diagram above). A prompt (that includes a token which will be mapped to this new embedding) is used in conjunction with a noised version of one or more training images as inputs to the generator model, which attempts to predict the denoised version of the image. The embedding is optimized based on how well the model does at this task - an embedding that better captures the object or style shown by the training images will give more useful information to the diffusion model and thus result in a lower denoising loss. After many steps (typically several thousand) with a variety of prompt and image variants the learned embedding should hopefully capture the essence of the new concept being taught.
+Before you begin, make sure you install the library's training dependencies:

-## Usage
+```bash
+pip install diffusers accelerate transformers
+```

-To train your own textual inversions, see the [example script here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion).
+After all the dependencies have been set up, initialize a [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

-There is also a notebook for training:
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+```bash
+accelerate config
+```

-And one for inference:
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+To set up a default 🤗 Accelerate environment without choosing any configurations:

-In addition to using concepts you have trained yourself, there is a community-created collection of trained textual inversions in the new [Stable Diffusion public concepts library](https://huggingface.co/sd-concepts-library) which you can also use from the inference notebook above. Over time this will hopefully grow into a useful resource as more examples are added.
+```bash
+accelerate config default
+```

-## Example: Running locally
+Or if your environment doesn't support an interactive shell like a notebook, you can use:

-The `textual_inversion.py` script [here](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion) shows how to implement the training procedure and adapt it for stable diffusion.
+```py
+from accelerate.utils import write_basic_config

-### Installing the dependencies
+write_basic_config()
+```

-Before running the scripts, make sure to install the library's training dependencies.
+Finally, try to [install xFormers](https://huggingface.co/docs/diffusers/main/en/training/optimization/xformers) to reduce your memory footprint with memory-efficient attention. Once xFormers is installed, add the `--enable_xformers_memory_efficient_attention` argument to the training script. xFormers is not supported for Flax.

-```bash
-pip install diffusers[training] accelerate transformers
-```
+## Upload model to Hub

-And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+If you want to store your model on the Hub, add the following argument to the training script:

```bash
-accelerate config
+--push_to_hub
```

+## Save and load checkpoints

-### Cat toy example
+It is often a good idea to regularly save checkpoints of your model during training. This way, you can resume training from a saved checkpoint if your training is interrupted for any reason. To save a checkpoint, pass the following argument to the training script to save the full training state in a subfolder in `output_dir` every 500 steps:

-You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
-
-You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+```bash
+--checkpointing_steps=500
+```

-Run the following command to authenticate your token
+To resume training from a saved checkpoint, pass the following argument to the training script and the specific checkpoint you'd like to resume from:

```bash
-huggingface-cli login
+--resume_from_checkpoint="checkpoint-1500"
```

-If you have already cloned the repo, then you won't need to go through these steps.
+## Finetuning
+
+For your training dataset, download these [images of a cat statue](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and store them in a directory.

-<br>
+Set the `MODEL_NAME` environment variable to the model repository id, and the `DATA_DIR` environment variable to the path of the directory containing the images. Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py):

-Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
+<Tip>

-And launch the training using
+💡 A full training run takes ~1 hour on one V100 GPU. While you're waiting for the training to complete, feel free to check out [how Textual Inversion works](#how-it-works) in the section below if you're curious!

+</Tip>
+
+<frameworkcontent>
+<pt>
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"
@@ -100,14 +111,56 @@ accelerate launch textual_inversion.py \
--lr_warmup_steps=0 \
--output_dir="textual_inversion_cat"
```
+</pt>
+<jax>
+If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️

-A full training run takes ~1 hour on one V100 GPU.
+Before you begin, make sure you install the Flax-specific dependencies:

+```bash
+pip install -U -r requirements_flax.txt
+```

-### Inference
+Then you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py):

-Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"

+python textual_inversion_flax.py \
+--pretrained_model_name_or_path=$MODEL_NAME \
+--train_data_dir=$DATA_DIR \
+--learnable_property="object" \
+--placeholder_token="<cat-toy>" --initializer_token="toy" \
+--resolution=512 \
+--train_batch_size=1 \
+--max_train_steps=3000 \
+--learning_rate=5.0e-04 --scale_lr \
+--output_dir="textual_inversion_cat"
+```
+</jax>
+</frameworkcontent>
+
+### Intermediate logging
+
+If you're interested in following along with your model training progress, you can save the generated images from the training process. Add the following arguments to the training script to enable intermediate logging:
+
+- `validation_prompt`, the prompt used to generate samples (this is set to `None` by default and intermediate logging is disabled)
+- `num_validation_images`, the number of sample images to generate
+- `validation_steps`, the number of steps before generating `num_validation_images` from the `validation_prompt`
+
+```bash
+--validation_prompt="A <cat-toy> backpack"
+--num_validation_images=4
+--validation_steps=100
+```
+
+## Inference
+
+Once you have trained a model, you can use it for inference with the [`StableDiffusionPipeline`]. Make sure you include the `placeholder_token` in your prompt; in this case, it is `<cat-toy>`.
+
+<frameworkcontent>
+<pt>
```python
from diffusers import StableDiffusionPipeline

@@ -120,3 +173,43 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("cat-backpack.png")
```
+</pt>
+<jax>
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+model_path = "path-to-your-trained-model"
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
+
+prompt = "A <cat-toy> backpack"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+images[0].save("cat-backpack.png")
+```
+</jax>
+</frameworkcontent>
+
+## How it works
+
+![Diagram from the paper showing overview](https://textual-inversion.github.io/static/images/training/training.JPG)
+<small>Architecture overview from the Textual Inversion <a href="https://textual-inversion.github.io/">blog post.</a></small>
+
+Usually, text prompts are tokenized and converted into embeddings before being passed to a model, which is often a transformer. Textual Inversion does something similar, but it learns a new token embedding, `v*`, from a special token `S*` in the diagram above. The model output is used to condition the diffusion model, which helps the diffusion model understand the prompt and new concepts from just a few example images.
+
+To do this, Textual Inversion uses a generator model and noisy versions of the training images. The generator tries to predict less noisy versions of the images, and the token embedding `v*` is optimized based on how well the generator does. If the token embedding successfully captures the new concept, it gives more useful information to the diffusion model and helps create clearer images with less noise. This optimization typically requires several thousand steps of exposure to a variety of prompt and image variants.
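
To make the mechanism above more concrete, here is a minimal, illustrative PyTorch sketch of the setup that happens before the denoising loop: a placeholder token is added to the tokenizer, its embedding `v*` is initialized from an existing token, and that embedding is the only parameter left trainable. This is a sketch rather than the official `textual_inversion.py` script; the model id and token names simply mirror the cat-toy example above, and the denoising-loss loop over the example images is omitted.

```py
# Minimal sketch of Textual Inversion's embedding setup (not the full training script).
import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

placeholder_token = "<cat-toy>"
initializer_token = "toy"

# Register the new "word" and grow the embedding matrix to make room for it.
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))

# Initialize v* from the initializer token's embedding so optimization starts from a sensible point.
token_embeds = text_encoder.get_input_embeddings().weight.data
placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)
initializer_id = tokenizer.convert_tokens_to_ids(initializer_token)
token_embeds[placeholder_id] = token_embeds[initializer_id].clone()

# Freeze the text encoder except for the input embeddings; during training you would also
# zero out the gradients of every embedding row except `placeholder_id` after each backward pass.
text_encoder.requires_grad_(False)
text_encoder.get_input_embeddings().requires_grad_(True)

optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=5.0e-04)
```

Keeping only this one embedding row trainable is what lets a new concept be learned from just a few images without changing the rest of the model.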

0 commit comments

Comments
 (0)