| 1 | +# DreamBooth training example for Stable Diffusion XL (SDXL) |
| 2 | + |
| 3 | +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject. |
| 4 | + |
| 5 | +The `train_dreambooth_lora_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). |
| 6 | + |
| 7 | +> 💡 **Note**: For now, we only allow DreamBooth fine-tuning of the SDXL UNet via LoRA. LoRA is a parameter-efficient fine-tuning technique introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. |
| 8 | + |
| 9 | +## Running locally with PyTorch |
| 10 | + |
| 11 | +### Installing the dependencies |
| 12 | + |
| 13 | +Before running the scripts, make sure to install the library's training dependencies: |
| 14 | + |
| 15 | +**Important** |
| 16 | + |
| 17 | +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: |
| 18 | + |
| 19 | +```bash |
| 20 | +git clone https://github.com/huggingface/diffusers |
| 21 | +cd diffusers |
| 22 | +pip install -e . |
| 23 | +``` |
| 24 | + |
| 25 | +Then `cd` into the `examples/dreambooth` folder and run: |
| 26 | +```bash |
| 27 | +pip install -r requirements_sdxl.txt |
| 28 | +``` |
| 29 | + |
| 30 | +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: |
| 31 | + |
| 32 | +```bash |
| 33 | +accelerate config |
| 34 | +``` |
| 35 | + |
| 36 | +Or, for a default Accelerate configuration that skips the questions about your environment: |
| 37 | + |
| 38 | +```bash |
| 39 | +accelerate config default |
| 40 | +``` |
| 41 | + |
| 42 | +Or, if your environment doesn't support an interactive shell (e.g., a notebook): |
| 43 | + |
| 44 | +```python |
| 45 | +from accelerate.utils import write_basic_config |
| 46 | +write_basic_config() |
| 47 | +``` |
| 48 | + |
| 49 | +When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups. |
| 50 | + |
| 51 | +### Dog toy example |
| 52 | + |
| 53 | +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. |
| 54 | + |
| 55 | +Let's first download it locally: |
| 56 | + |
| 57 | +```python |
| 58 | +from huggingface_hub import snapshot_download |
| 59 | + |
| 60 | +local_dir = "./dog" |
| 61 | +snapshot_download( |
| 62 | + "diffusers/dog-example", |
| 63 | + local_dir=local_dir, repo_type="dataset", |
| 64 | + ignore_patterns=".gitattributes", |
| 65 | +) |
| 66 | +``` |
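Before launching training, it can help to confirm that the images actually landed in `local_dir`. The following is a minimal sketch, not part of the training script: the helper name `list_images` is hypothetical, and the set of image extensions is an assumption about what the dataset contains.

```python
from pathlib import Path

# Assumed set of image extensions; adjust if your dataset uses others.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def list_images(folder):
    """Return the sorted image paths found directly inside `folder`.

    Hypothetical helper for a quick sanity check; returns an empty list
    when the folder does not exist.
    """
    root = Path(folder)
    if not root.is_dir():
        return []
    return sorted(
        p for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


# "./dog" matches the `local_dir` used in the download snippet above.
images = list_images("./dog")
print(f"Found {len(images)} image(s) in ./dog")
```

If the count is zero, the download most likely went to a different directory than the one passed as `--instance_data_dir`.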
| 67 | + |
| 68 | +Since SDXL 0.9 weights are gated, we need to be authenticated to be able to use them. So, let's run: |
| 69 | + |
| 70 | +```bash |
| 71 | +huggingface-cli login |
| 72 | +``` |
| 73 | + |
| 74 | +This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. |
| 75 | + |
| 76 | +Now, we can launch training using: |
| 77 | + |
| 78 | +```bash |
| 79 | +export MODEL_NAME="diffusers/stable-diffusion-xl-base-0.9" |
| 80 | +export INSTANCE_DIR="dog" |
| 81 | +export OUTPUT_DIR="lora-trained-xl" |
| 82 | + |
| 83 | +accelerate launch train_dreambooth_lora_sdxl.py \ |
| 84 | + --pretrained_model_name_or_path=$MODEL_NAME \ |
| 85 | + --instance_data_dir=$INSTANCE_DIR \ |
| 86 | + --output_dir=$OUTPUT_DIR \ |
| 87 | + --mixed_precision="fp16" \ |
| 88 | + --instance_prompt="a photo of sks dog" \ |
| 89 | + --resolution=1024 \ |
| 90 | + --train_batch_size=1 \ |
| 91 | + --gradient_accumulation_steps=4 \ |
| 92 | + --learning_rate=1e-4 \ |
| 93 | + --report_to="wandb" \ |
| 94 | + --lr_scheduler="constant" \ |
| 95 | + --lr_warmup_steps=0 \ |
| 96 | + --max_train_steps=500 \ |
| 97 | + --validation_prompt="A photo of sks dog in a bucket" \ |
| 98 | + --validation_epochs=25 \ |
| 99 | + --seed="0" \ |
| 100 | + --push_to_hub |
| 101 | +``` |
| 102 | + |
| 103 | +To better track our training experiments, we're using the following flags in the command above: |
| 104 | + |
| 105 | +* `report_to="wandb"` ensures the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. |
| 106 | +* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, which lets us qualitatively check whether the training is progressing as expected. |
| 107 | + |
| 108 | +Our experiments were conducted on a single 40GB A100 GPU. |
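The flags in the command above imply an effective batch size larger than `train_batch_size`. The sketch below works through the arithmetic. It assumes a single process (one GPU) and assumes that `max_train_steps` counts optimizer updates rather than individual forward passes, which is worth verifying against the script. The function names are illustrative and are not part of the script.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Number of samples contributing to each optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def samples_seen(max_train_steps, batch):
    """Total training samples processed, assuming one optimizer update per step."""
    return max_train_steps * batch


# Values taken from the launch command: batch size 1, accumulation 4, 500 steps.
batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
total = samples_seen(max_train_steps=500, batch=batch)
print(batch)  # 4
print(total)  # 2000
```

With only a handful of instance images, 2000 samples means each image is revisited many times, which is expected for DreamBooth-style fine-tuning.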
| 109 | + |
| 110 | +### Inference |
| 111 | + |
| 112 | +Once training is done, we can perform inference like so: |
| 113 | + |
| 114 | +```python |
| 115 | +from huggingface_hub.repocard import RepoCard |
| 116 | +from diffusers import DiffusionPipeline |
| 117 | +import torch |
| 118 | + |
| 119 | +lora_model_id = "<lora-sdxl-dreambooth-id>"  # placeholder: replace with the Hub id of your trained LoRA repo |
| 120 | +card = RepoCard.load(lora_model_id) |
| 121 | +base_model_id = card.data.to_dict()["base_model"] |
| 122 | + |
| 123 | +pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) |
| 124 | +pipe = pipe.to("cuda") |
| 125 | +pipe.load_lora_weights(lora_model_id) |
| 126 | +image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] |
| 127 | +image.save("sks_dog.png") |
| 128 | +``` |
| 129 | + |
| 130 | +We can further refine the outputs with the [Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9): |
| 131 | + |
| 132 | +```python |
| 133 | +from huggingface_hub.repocard import RepoCard |
| 134 | +from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline |
| 135 | +import torch |
| 136 | + |
| 137 | +lora_model_id = "<lora-sdxl-dreambooth-id>"  # placeholder: replace with the Hub id of your trained LoRA repo |
| 138 | +card = RepoCard.load(lora_model_id) |
| 139 | +base_model_id = card.data.to_dict()["base_model"] |
| 140 | + |
| 141 | +# Load the base pipeline and load the LoRA parameters into it. |
| 142 | +pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) |
| 143 | +pipe = pipe.to("cuda") |
| 144 | +pipe.load_lora_weights(lora_model_id) |
| 145 | + |
| 146 | +# Load the refiner. |
| 147 | +refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( |
| 148 | + "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16" |
| 149 | +) |
| 150 | +refiner.to("cuda") |
| 151 | + |
| 152 | +prompt = "A picture of a sks dog in a bucket" |
| 153 | +generator = torch.Generator("cuda").manual_seed(0) |
| 154 | + |
| 155 | +# Run inference. |
| 156 | +image = pipe(prompt=prompt, output_type="latent", generator=generator).images[0] |
| 157 | +image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0] |
| 158 | +image.save("refined_sks_dog.png") |
| 159 | +``` |
| 160 | + |
| 161 | +Here's a side-by-side comparison of the pipeline outputs with and without the Refiner: |
| 162 | + |
| 163 | +| Without Refiner | With Refiner | |
| 164 | +|---|---| |
| 165 | +|  |  | |
| 166 | + |
| 167 | +## Notes |
| 168 | + |
| 169 | +In our experiments, we found that SDXL yields very good initial results using the default settings of the script. We didn't run further hyperparameter tuning experiments, but we encourage the community to explore this avenue and share their results with us 🤗 |
| 170 | + |
| 171 | +## Results |
| 172 | + |
| 173 | +You can explore the results from a couple of our internal experiments by checking out this link: [https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl](https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl). Specifically, we used the same script with the exact same hyperparameters on the following datasets: |
| 174 | + |
| 175 | +* [Dogs](https://huggingface.co/datasets/diffusers/dog-example) |
| 176 | +* [Starbucks logo](https://huggingface.co/datasets/diffusers/starbucks-example) |
| 177 | +* [Mr. Potato Head](https://huggingface.co/datasets/diffusers/potato-head-example) |
| 178 | +* [Keramer face](https://huggingface.co/datasets/diffusers/keramer-face-example) |