
Commit c6e08ec

linoytsaban, yiyixuxu, and sayakpaul authored
[Sd3 Dreambooth LoRA] Add text encoder training for the clip encoders (huggingface#8630)
* add clip text-encoder training * no dora * text encoder traing fixes * text encoder traing fixes * text encoder training fixes * text encoder training fixes * text encoder training fixes * text encoder training fixes * add text_encoder layers to save_lora * style * fix imports * style * fix text encoder * review changes * review changes * review changes * minor change * add lora tag * style * add readme notes * add tests for clip encoders * style * typo * fixes * style * Update tests/lora/test_lora_layers_sd3.py Co-authored-by: Sayak Paul <[email protected]> * Update examples/dreambooth/README_sd3.md Co-authored-by: Sayak Paul <[email protected]> * minor readme change --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 4ad7a1f commit c6e08ec

File tree

4 files changed: +351 -54 lines changed

examples/dreambooth/README_sd3.md

Lines changed: 34 additions & 0 deletions
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --push_to_hub
```
### Text Encoder Training

Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
> By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported, and its weights remain frozen when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="trained-sd3-lora"

accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --output_dir=$OUTPUT_DIR \
  --dataset_name="Norod78/Yarn-art-style" \
  --instance_prompt="a photo of TOK yarn art dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --train_text_encoder \
  --gradient_accumulation_steps=1 \
  --optimizer="prodigy" \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --rank=32 \
  --seed="0" \
  --push_to_hub
```
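Once training completes, the resulting LoRA (which now includes layers for both CLIP text encoders) can be loaded for inference. The snippet below is a minimal sketch, assuming the weights were saved to the `trained-sd3-lora` directory (or pushed to a Hub repo of the same name) and that a CUDA device is available:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Load the base SD3 pipeline in half precision.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
# load_lora_weights picks up the transformer LoRA layers as well as the
# CLIP text encoder layers saved when --train_text_encoder was enabled.
pipe.load_lora_weights("trained-sd3-lora")
pipe.to("cuda")

image = pipe("a photo of TOK yarn art dog", num_inference_steps=28).images[0]
image.save("yarn_art_dog.png")
```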
## Other notes

We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that, for the other weighting schemes supported by the training script, training may incur numerical instabilities.
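For intuition, the logit-normal scheme from the SD3 paper draws a sample from a Gaussian and squashes it through a sigmoid, so the sampled timesteps in (0, 1) concentrate around mid-range noise levels rather than the extremes. The following is a standalone sketch of that idea, not the training script's exact helper; the `logit_mean` and `logit_std` parameters here are named after the corresponding knobs in the paper, and whether they match the script's flags one-to-one is an assumption worth checking:

```python
import torch

def sample_logit_normal_timesteps(
    batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0
) -> torch.Tensor:
    # u ~ Normal(logit_mean, logit_std); sigmoid(u) maps it into (0, 1),
    # placing most density near sigmoid(logit_mean) and little at 0 or 1.
    u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,))
    return torch.sigmoid(u)

# Example: four timesteps clustered around 0.5 under the default parameters.
print(sample_logit_normal_timesteps(4))
```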
