Commit 52f2128

update readme for flax examples (huggingface#1026)

1 parent fbcc383 commit 52f2128

File tree: 6 files changed, +179 -109 lines

examples/dreambooth/README.md

Lines changed: 75 additions & 57 deletions
@@ -4,7 +4,7 @@
 The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
 
 
-## Running locally
+## Running locally with PyTorch
 ### Installing the dependencies
 
 Before running the scripts, make sure to install the library's training dependencies:
@@ -58,24 +58,6 @@ accelerate launch train_dreambooth.py \
   --max_train_steps=400
 ```
 
-Or use the Flax implementation if you need a speedup
-
-```bash
-export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export INSTANCE_DIR="path-to-instance-images"
-export OUTPUT_DIR="path-to-save-model"
-
-python train_dreambooth_flax.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --instance_prompt="a photo of sks dog" \
-  --resolution=512 \
-  --train_batch_size=1 \
-  --learning_rate=5e-6 \
-  --max_train_steps=400
-```
-
 ### Training with prior-preservation loss
 
 Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
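To make that concrete: the class images are nothing special, just ordinary samples from the base model for the generic class prompt, which the prior-preservation loss then pulls the fine-tuned model back toward. A minimal sketch of generating them by hand (the training script does this automatically via `--num_class_images`; the base-model id and paths here are assumptions):

```python
from diffusers import StableDiffusionPipeline
import torch

# Sample generic "class" images from the base model; during training they
# anchor the prior so the model still knows what an ordinary dog looks like
# while it learns the "sks dog" instance.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

for i in range(4):  # the commands below use --num_class_images=200 in practice
    image = pipe("a photo of dog", num_inference_steps=50).images[0]
    image.save(f"path-to-class-images/class_{i}.png")
```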
@@ -105,28 +87,6 @@ accelerate launch train_dreambooth.py \
   --max_train_steps=800
 ```
 
-Or use the Flax implementation if you need a speedup
-
-```bash
-export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export INSTANCE_DIR="path-to-instance-images"
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="path-to-save-model"
-
-python train_dreambooth_flax.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
-  --class_data_dir=$CLASS_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --with_prior_preservation --prior_loss_weight=1.0 \
-  --instance_prompt="a photo of sks dog" \
-  --class_prompt="a photo of dog" \
-  --resolution=512 \
-  --train_batch_size=1 \
-  --learning_rate=5e-6 \
-  --num_class_images=200 \
-  --max_train_steps=800
-```
-
 ### Training on a 16GB GPU:
 
@@ -234,7 +194,58 @@ accelerate launch train_dreambooth.py \
   --max_train_steps=800
 ```
 
-Or use the Flax implementation if you need a speedup
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
+
+
+## Running with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+___Note: The Flax examples don't yet support features like gradient checkpointing and gradient accumulation, so to use Flax for faster training you will need accelerators with more than 30GB of memory.___
+
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+
+### Training without prior preservation loss
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --learning_rate=5e-6 \
+  --max_train_steps=400
+```
+
+
+### Training with prior preservation loss
 
 ```bash
 export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
@@ -244,7 +255,6 @@ export OUTPUT_DIR="path-to-save-model"
 
 python train_dreambooth_flax.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --train_text_encoder \
   --instance_data_dir=$INSTANCE_DIR \
   --class_data_dir=$CLASS_DIR \
   --output_dir=$OUTPUT_DIR \
@@ -253,24 +263,32 @@ python train_dreambooth_flax.py \
   --class_prompt="a photo of dog" \
   --resolution=512 \
   --train_batch_size=1 \
-  --learning_rate=2e-6 \
+  --learning_rate=5e-6 \
   --num_class_images=200 \
   --max_train_steps=800
 ```
 
-## Inference
 
-Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
-
-```python
-from diffusers import StableDiffusionPipeline
-import torch
-
-model_id = "path-to-your-trained-model"
-pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+### Fine-tune text encoder with the UNet.
 
-prompt = "A photo of sks dog in a bucket"
-image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
 
-image.save("dog-bucket.png")
-```
+python train_dreambooth_flax.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_text_encoder \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --learning_rate=2e-6 \
+  --num_class_images=200 \
+  --max_train_steps=800
+```
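The Flax commands above save a Flax-format pipeline to `$OUTPUT_DIR`. A minimal inference sketch for such a checkpoint, assuming the `FlaxStableDiffusionPipeline` API from `diffusers` and that `path-to-save-model` is the `$OUTPUT_DIR` used during training:

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

# Load the pipeline saved by train_dreambooth_flax.py
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path-to-save-model")

num_devices = jax.device_count()
prompt = ["A photo of sks dog in a bucket"] * num_devices

# jit=True pmaps generation across devices, so inputs, params, and RNG keys
# are sharded/replicated per device first (one prompt per device here)
prompt_ids = shard(pipeline.prepare_inputs(prompt))
params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

images = pipeline(prompt_ids, params, rng, num_inference_steps=50, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_devices,) + images.shape[-3:])))
images[0].save("dog-bucket.png")
```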
examples/dreambooth/requirements_flax.txt

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+diffusers>=0.5.1
+transformers>=4.21.0
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+modelcards
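With these dependencies installed, and given the memory note above, it's worth confirming that JAX actually sees your accelerators before launching a long run; a quick sanity check (a minimal sketch, nothing project-specific):

```python
import jax

# Lists the accelerators JAX will use, e.g. TPU cores or CUDA GPUs;
# an unexpected [CpuDevice(...)] means the Flax run will be very slow.
print(jax.devices())
print(f"{jax.device_count()} device(s) visible")
```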

examples/text_to_image/README.md

Lines changed: 48 additions & 32 deletions
@@ -7,7 +7,7 @@ ___Note___:
 ___This script is experimental. The script fine-tunes the whole model and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
 
 
-## Running locally
+## Running locally with PyTorch
 ### Installing the dependencies
 
 Before running the scripts, make sure to install the library's training dependencies:
@@ -62,24 +62,6 @@ accelerate launch train_text_to_image.py \
   --output_dir="sd-pokemon-model"
 ```
 
-Or use the Flax implementation if you need a speedup
-
-```bash
-export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export dataset_name="lambdalabs/pokemon-blip-captions"
-
-python train_text_to_image_flax.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
-  --resolution=512 --center_crop --random_flip \
-  --train_batch_size=1 \
-  --mixed_precision="fp16" \
-  --max_train_steps=15000 \
-  --learning_rate=1e-05 \
-  --max_grad_norm=1 \
-  --output_dir="sd-pokemon-model"
-```
-
 
 To run on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
 If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
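As a concrete illustration of that format, an `imagefolder` dataset with captions is just a folder of images plus a `metadata.jsonl` pairing each file with its caption; the layout and file names below are hypothetical:

```python
# path_to_your_dataset/
# ├── metadata.jsonl   # one JSON object per line, e.g.
# │                    # {"file_name": "0001.png", "text": "a drawing of a green pokemon"}
# ├── 0001.png
# └── 0002.png
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="path_to_your_dataset")
print(dataset["train"][0])  # {'image': <PIL.Image.Image ...>, 'text': 'a drawing of a green pokemon'}
```

A directory in this shape is what `--train_data_dir` in the commands below expects.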
@@ -104,34 +86,68 @@ accelerate launch train_text_to_image.py \
   --output_dir="sd-pokemon-model"
 ```
 
-Or use the Flax implementation if you need a speedup
+
+Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:
+
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+model_path = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
+
+
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+___Note: The Flax examples don't yet support features like gradient checkpointing and gradient accumulation, so to use Flax for faster training you will need accelerators with more than 30GB of memory.___
+
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
 
 ```bash
 export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export TRAIN_DIR="path_to_your_dataset"
+export dataset_name="lambdalabs/pokemon-blip-captions"
 
 python train_text_to_image_flax.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --train_data_dir=$TRAIN_DIR \
+  --dataset_name=$dataset_name \
   --resolution=512 --center_crop --random_flip \
   --train_batch_size=1 \
   --mixed_precision="fp16" \
   --max_train_steps=15000 \
   --learning_rate=1e-05 \
   --max_grad_norm=1 \
-  --output_dir="sd-pokemon-model"
+  --output_dir="sd-pokemon-model"
 ```
 
-Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
-
 
-```python
-from diffusers import StableDiffusionPipeline
+To run on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
 
-model_path = "path_to_saved_model"
-pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
-pipe.to("cuda")
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export TRAIN_DIR="path_to_your_dataset"
 
-image = pipe(prompt="yoda").images[0]
-image.save("yoda-pokemon.png")
+python train_text_to_image_flax.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$TRAIN_DIR \
+  --resolution=512 --center_crop --random_flip \
+  --train_batch_size=1 \
+  --mixed_precision="fp16" \
+  --max_train_steps=15000 \
+  --learning_rate=1e-05 \
+  --max_grad_norm=1 \
+  --output_dir="sd-pokemon-model"
 ```
examples/text_to_image/requirements_flax.txt

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+diffusers>=0.5.1
+transformers>=4.21.0
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+modelcards

examples/textual_inversion/README.md

Lines changed: 29 additions & 20 deletions
@@ -11,7 +11,7 @@ Colab for training
 Colab for inference
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
 
-## Running locally
+## Running locally with PyTorch
 ### Installing the dependencies
 
 Before running the scripts, make sure to install the library's training dependencies:
@@ -68,25 +68,6 @@ accelerate launch textual_inversion.py \
 
 A full training run takes ~1 hour on one V100 GPU.
 
-If you want to speed it up even more, Flax implementation is available:
-
-```bash
-export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export DATA_DIR="path-to-dir-containing-images"
-
-python textual_inversion_flax.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --train_data_dir=$DATA_DIR \
-  --learnable_property="object" \
-  --placeholder_token="<cat-toy>" --initializer_token="toy" \
-  --resolution=512 \
-  --train_batch_size=1 \
-  --max_train_steps=3000 \
-  --learning_rate=5.0e-04 --scale_lr \
-  --output_dir="textual_inversion_cat"
-```
-It should be at least 70% faster than the PyTorch script with the same configuration.
-
 ### Inference
 
 Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
@@ -103,3 +84,31 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 
 image.save("cat-backpack.png")
 ```
+
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATA_DIR \
+  --learnable_property="object" \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --max_train_steps=3000 \
+  --learning_rate=5.0e-04 --scale_lr \
+  --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
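Beyond loading the full saved pipeline as shown in the Inference section above, the PyTorch script also writes the learned embedding on its own as `learned_embeds.bin` in the output directory (an assumption about the script's defaults worth checking against your run). A sketch of injecting that embedding into a fresh base pipeline, with the base-model id assumed:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

# learned_embeds.bin maps the placeholder token to its trained embedding vector
learned = torch.load("textual_inversion_cat/learned_embeds.bin")
token, embedding = next(iter(learned.items()))

# Register the placeholder token and copy in the trained embedding
pipe.tokenizer.add_tokens(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

image = pipe("A <cat-toy> backpack", num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```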
examples/textual_inversion/requirements_flax.txt

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+diffusers>=0.5.1
+transformers>=4.21.0
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+modelcards
