Skip to content

Commit ccb93dc

Browse files
Support EDM-style training in DreamBooth LoRA SDXL script (huggingface#7126)
* add: dreambooth lora script for Playground v2.5 * fix: kwarg * address suraj's comments. * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * apply suraj's suggestion * incorporate changes in the canonical script./ * tracker naming * fix: schedule determination * add: two simple tests * remove playground script * note about edm-style training * address pedro's comments. * address part of Suraj's comments. * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * remove guidance_scale. * use mse_loss. * add comments for preconditioning. * quality * Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Suraj Patil <[email protected]> * tackle v-pred. * Empty-Commit * support edm for sdxl too. * address suraj's comments. * Empty-Commit --------- Co-authored-by: Suraj Patil <[email protected]>
1 parent ec95304 commit ccb93dc

File tree

3 files changed

+307
-29
lines changed

3 files changed

+307
-29
lines changed

examples/dreambooth/README_sdxl.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,40 @@ You can explore the results from a couple of our internal experiments by checkin
206206
## Running on a free-tier Colab Notebook
207207

208208
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
209+
210+
## Conducting EDM-style training
211+
212+
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
213+
214+
For the SDXL model, simple set:
215+
216+
```diff
217+
+ --do_edm_style_training \
218+
```
219+
220+
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
221+
222+
```bash
223+
accelerate launch train_dreambooth_lora_sdxl.py \
224+
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
225+
--instance_data_dir="dog" \
226+
--output_dir="dog-playground-lora" \
227+
--mixed_precision="fp16" \
228+
--instance_prompt="a photo of sks dog" \
229+
--resolution=1024 \
230+
--train_batch_size=1 \
231+
--gradient_accumulation_steps=4 \
232+
--learning_rate=1e-4 \
233+
--use_8bit_adam \
234+
--report_to="wandb" \
235+
--lr_scheduler="constant" \
236+
--lr_warmup_steps=0 \
237+
--max_train_steps=500 \
238+
--validation_prompt="A photo of sks dog in a bucket" \
239+
--validation_epochs=25 \
240+
--seed="0" \
241+
--push_to_hub
242+
```
243+
244+
> [!CAUTION]
245+
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
import safetensors
22+
23+
24+
sys.path.append("..")
25+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
26+
27+
28+
logging.basicConfig(level=logging.DEBUG)
29+
30+
logger = logging.getLogger()
31+
stream_handler = logging.StreamHandler(sys.stdout)
32+
logger.addHandler(stream_handler)
33+
34+
35+
class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
36+
def test_dreambooth_lora_sdxl_with_edm(self):
37+
with tempfile.TemporaryDirectory() as tmpdir:
38+
test_args = f"""
39+
examples/dreambooth/train_dreambooth_lora_sdxl.py
40+
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
41+
--do_edm_style_training
42+
--instance_data_dir docs/source/en/imgs
43+
--instance_prompt photo
44+
--resolution 64
45+
--train_batch_size 1
46+
--gradient_accumulation_steps 1
47+
--max_train_steps 2
48+
--learning_rate 5.0e-04
49+
--scale_lr
50+
--lr_scheduler constant
51+
--lr_warmup_steps 0
52+
--output_dir {tmpdir}
53+
""".split()
54+
55+
run_command(self._launch_args + test_args)
56+
# save_pretrained smoke test
57+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
58+
59+
# make sure the state_dict has the correct naming in the parameters.
60+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
61+
is_lora = all("lora" in k for k in lora_state_dict.keys())
62+
self.assertTrue(is_lora)
63+
64+
# when not training the text encoder, all the parameters in the state dict should start
65+
# with `"unet"` in their names.
66+
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
67+
self.assertTrue(starts_with_unet)
68+
69+
def test_dreambooth_lora_playground(self):
70+
with tempfile.TemporaryDirectory() as tmpdir:
71+
test_args = f"""
72+
examples/dreambooth/train_dreambooth_lora_sdxl.py
73+
--pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
74+
--instance_data_dir docs/source/en/imgs
75+
--instance_prompt photo
76+
--resolution 64
77+
--train_batch_size 1
78+
--gradient_accumulation_steps 1
79+
--max_train_steps 2
80+
--learning_rate 5.0e-04
81+
--scale_lr
82+
--lr_scheduler constant
83+
--lr_warmup_steps 0
84+
--output_dir {tmpdir}
85+
""".split()
86+
87+
run_command(self._launch_args + test_args)
88+
# save_pretrained smoke test
89+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
90+
91+
# make sure the state_dict has the correct naming in the parameters.
92+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
93+
is_lora = all("lora" in k for k in lora_state_dict.keys())
94+
self.assertTrue(is_lora)
95+
96+
# when not training the text encoder, all the parameters in the state dict should start
97+
# with `"unet"` in their names.
98+
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
99+
self.assertTrue(starts_with_unet)

0 commit comments

Comments
 (0)