Commit 3d74dc2

sayakpaul and pcuenca authored
[Examples] Add a training script for SDXL DreamBooth LoRA (huggingface#4016)
* add dreambooth lora script for SDXL incorporating latest changes.
* remove use_auth_token=True.
* add: documentation
* remove unneeded cli.
* increase the number of training steps in the readme.
* add LoraLoaderMixin to the subclassing mix.
* add sdxl lora dreambooth test.
* add: inference code sample.
* add: refiner output.
* add LoraLoaderMixin to the mix of classes of StableDiffusionXLImg2ImgPipeline.
* change default resolution of DreamBoothDataset.
* better sdxl report path.
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent dfd7eaf commit 3d74dc2

8 files changed: +1493 -2 lines changed

docs/source/en/training/dreambooth.mdx

Lines changed: 4 additions & 0 deletions
@@ -701,3 +701,7 @@ accelerate launch train_dreambooth.py \
701701
--class_labels_conditioning timesteps \
702702
--push_to_hub
703703
```
704+
705+
## Stable Diffusion XL
706+
707+
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md).

examples/dreambooth/README.md

Lines changed: 4 additions & 0 deletions
@@ -737,3 +737,7 @@ accelerate launch train_dreambooth.py \
737737
--class_labels_conditioning timesteps \
738738
--push_to_hub
739739
```
740+
741+
## Stable Diffusion XL
742+
743+
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).

examples/dreambooth/README_sdxl.md

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
1+
# DreamBooth training example for Stable Diffusion XL (SDXL)
2+
3+
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3 to 5) images of a subject.
4+
5+
The `train_dreambooth_lora_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
6+
7+
> 💡 **Note**: For now, we only allow DreamBooth fine-tuning of the SDXL UNet via LoRA. LoRA is a parameter-efficient fine-tuning technique introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
8+
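
To make the "parameter-efficient" claim in the note above concrete, here is a minimal pure-Python sketch of the low-rank update described in the LoRA paper: a frozen weight `W` is adapted as `W + (alpha / rank) * B @ A`, and only `A` and `B` are trained. The layer size (1280 x 1280), the rank (4), and all function names are illustrative assumptions; they are not taken from `train_dreambooth_lora_sdxl.py`.

```python
def lora_param_counts(d_out, d_in, rank):
    """Return (full, low_rank) trainable-parameter counts for one linear layer."""
    full = d_out * d_in                # training the whole weight matrix W
    low_rank = rank * (d_in + d_out)   # training only B (d_out x rank) and A (rank x d_in)
    return full, low_rank


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(W, B, A, alpha, rank):
    """Compute W + (alpha / rank) * (B @ A) without modifying W."""
    scale = alpha / rank
    delta = matmul(B, A)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(W, delta)
    ]


if __name__ == "__main__":
    # Hypothetical layer size and rank, chosen only for illustration.
    full, low_rank = lora_param_counts(d_out=1280, d_in=1280, rank=4)
    print(f"full fine-tuning: {full} parameters, LoRA: {low_rank} parameters")
```

For this hypothetical layer, the low-rank factors hold well under 1% of the parameters of the full matrix, which is why LoRA fine-tuning needs far less memory for gradients and optimizer state.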
9+
## Running locally with PyTorch
10+
11+
### Installing the dependencies
12+
13+
Before running the scripts, make sure to install the library's training dependencies:
14+
15+
**Important**
16+
17+
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
18+
19+
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
24+
25+
Then `cd` into the `examples/dreambooth` folder and run:

```bash
pip install -r requirements_sdxl.txt
```
29+
30+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
31+
32+
```bash
accelerate config
```
35+
36+
Or, for a default Accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```
41+
42+
Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```
48+
49+
When running `accelerate config`, enabling torch compile mode can yield dramatic speedups.
50+
51+
### Dog toy example
52+
53+
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
54+
55+
Let's first download it locally:
56+
57+
```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```
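
Before launching training, it can help to confirm that the download directory actually contains a handful of images, since DreamBooth expects only a few (3 to 5) pictures of the subject. The sketch below is a hypothetical sanity check using only the standard library; the helper name `list_instance_images` and the set of accepted file suffixes are assumptions, not part of the training script.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your own dataset.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def list_instance_images(instance_dir):
    """Return the sorted image files found directly inside `instance_dir`."""
    root = Path(instance_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Instance directory not found: {root}")
    return sorted(
        path
        for path in root.iterdir()
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


if __name__ == "__main__":
    images = list_instance_images("./dog")
    print(f"Found {len(images)} instance images")
    for path in images:
        print(" ", path.name)
```

If the count is zero, the most likely causes are a wrong `local_dir` or a download that did not complete.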
67+
68+
Since SDXL 0.9 weights are gated, we need to be authenticated to be able to use them. So, let's run:
69+
70+
```bash
huggingface-cli login
```
73+
74+
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
75+
76+
Now, we can launch training using:
77+
78+
```bash
export MODEL_NAME="diffusers/stable-diffusion-xl-base-0.9"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained-xl"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```
102+
103+
To better track our training experiments, we're using the following flags in the command above:
104+
105+
* `report_to="wandb"` ensures the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs during training, so we can qualitatively check whether training is progressing as expected.
107+
108+
Our experiments were conducted on a single 40GB A100 GPU.
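
As a rough guide to what the batch-related flags in the training command add up to, the sketch below computes the effective batch size and the number of training samples processed. It assumes the usual convention that the effective batch size is `train_batch_size * gradient_accumulation_steps * num_processes` and that `max_train_steps` counts optimizer updates; both are assumptions about the script's behavior, and the helper names are hypothetical.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Samples contributing to a single optimizer update across all processes."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def samples_processed(max_train_steps, batch_size):
    """Total training samples seen, assuming one optimizer update per step."""
    return max_train_steps * batch_size


if __name__ == "__main__":
    # Values taken from the training command: batch size 1, 4 accumulation steps,
    # 500 training steps, run on a single GPU (one process).
    batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
    total = samples_processed(max_train_steps=500, batch_size=batch)
    print(f"effective batch size: {batch}, samples processed: {total}")
```

Under these assumptions, raising `gradient_accumulation_steps` increases the effective batch size without increasing peak memory per forward pass, at the cost of more forward and backward passes per update.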
109+
110+
### Inference
111+
112+
Once training is done, we can perform inference like so:
113+
114+
```python
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
import torch

# Replace the placeholder with the Hub repo id of your trained LoRA.
lora_model_id = "<lora-sdxl-dreambooth-id>"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```
129+
130+
We can further refine the outputs with the [Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9):
131+
132+
```python
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
import torch

# Replace the placeholder with the Hub repo id of your trained LoRA.
lora_model_id = "<lora-sdxl-dreambooth-id>"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

# Load the base pipeline and load the LoRA parameters into it.
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)

# Load the refiner.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
refiner.to("cuda")

prompt = "A picture of a sks dog in a bucket"
generator = torch.Generator("cuda").manual_seed(0)

# Run inference: the base pipeline returns latents, which the refiner then denoises further.
image = pipe(prompt=prompt, output_type="latent", generator=generator).images[0]
# `image[None, :]` adds a leading batch dimension before passing the latents to the refiner.
image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]
image.save("refined_sks_dog.png")
```
160+
161+
Here's a side-by-side comparison of the pipeline outputs with and without the Refiner:
162+
163+
| Without Refiner | With Refiner |
|---|---|
| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
166+
167+
## Notes
168+
169+
In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
170+
171+
## Results
172+
173+
You can explore the results from a couple of our internal experiments by checking out this link: [https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl](https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl). Specifically, we used the same script with the exact same hyperparameters on the following datasets:
174+
175+
* [Dogs](https://huggingface.co/datasets/diffusers/dog-example)
* [Starbucks logo](https://huggingface.co/datasets/diffusers/starbucks-example)
* [Mr. Potato Head](https://huggingface.co/datasets/diffusers/potato-head-example)
* [Keramer face](https://huggingface.co/datasets/diffusers/keramer-face-example)
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=2.0
