
Commit b5814c5

add DoRA training feature to sdxl dreambooth lora script (huggingface#7235)
* dora in canonical script
* add mention of DoRA to readme
1 parent 9940573 commit b5814c5

File tree

2 files changed: 37 additions & 0 deletions

examples/dreambooth/README_sdxl.md

Lines changed: 26 additions & 0 deletions
@@ -243,3 +243,29 @@ accelerate launch train_dreambooth_lora_sdxl.py \
> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".

### DoRA training
The script now supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
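For intuition, here is a minimal NumPy sketch of the weight re-parameterization described in the paper (illustrative only, not the `peft` implementation; the toy shapes and the column-wise norm are assumptions made for this example):

```python
import numpy as np

# Frozen pre-trained weight and LoRA factors (toy shapes).
d_out, d_in, r = 64, 32, 4
W0 = np.random.randn(d_out, d_in)
B = np.zeros((d_out, r))             # trainable LoRA factor, initialized to zero
A = np.random.randn(r, d_in) * 0.01  # trainable LoRA factor

# DoRA: decompose W into a magnitude vector m and a direction V,
# and update only the direction with the low-rank term.
m = np.linalg.norm(W0, axis=0, keepdims=True)          # magnitude, init to column norms of W0
V = W0 + B @ A                                         # direction updated via LoRA
W = m * V / np.linalg.norm(V, axis=0, keepdims=True)   # re-composed weight
```

Because the re-composed `W` is a single matrix, the adapter can be folded back into the base weights after training, which is consistent with the "no additional overhead during inference" point above.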

> [!NOTE]
> 💡 DoRA training is still _experimental_
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
> Specifically, we've noticed 2 differences to take into account in your training:
> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA).
> 2. **DoRA quality is superior to LoRA, especially at lower ranks:** the difference in quality between a DoRA of rank 8 and a LoRA of rank 8 appears to be more significant than when training at ranks of 32 or 64, for example.
> This is also aligned with some of the quantitative analysis shown in the paper.

**Usage**
1. To use DoRA you need to install `peft` from main:
```bash
pip install git+https://github.com/huggingface/peft.git
```
2. Enable DoRA training by adding the `--use_dora` flag to your training command (see the example sketch right after this list):
```bash
--use_dora
```
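For reference, a hedged sketch of a full launch command with DoRA enabled; the flags other than `--use_dora` mirror the earlier SDXL DreamBooth LoRA examples in this README, and the model name, paths, and hyperparameter values are placeholders to adapt to your setup:

```bash
accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --instance_data_dir="dog" \
  --output_dir="dora-trained-xl" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --learning_rate=1e-4 \
  --max_train_steps=500 \
  --rank=8 \
  --use_dora \
  --mixed_precision="fp16" \
  --seed=0
```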
**Inference**
The inference is the same as if you trained a regular LoRA 🤗
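For example, a minimal sketch of loading the trained weights with diffusers; the model ID and output path are placeholders, and this follows the standard LoRA loading flow rather than anything DoRA-specific:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load the weights saved by the training script (placeholder output directory).
pipeline.load_lora_weights("dora-trained-xl")

image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog.png")
```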

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 11 additions & 0 deletions
@@ -647,6 +647,15 @@ def parse_args(input_args=None):
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--use_dora",
        action="store_true",
        default=False,
        help=(
            "Whether to train a DoRA as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
            "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -1147,6 +1156,7 @@ def main(args):
    # now we will add new LoRA weights to the attention layers
    unet_lora_config = LoraConfig(
        r=args.rank,
        use_dora=args.use_dora,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -1158,6 +1168,7 @@ def main(args):
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            use_dora=args.use_dora,
            lora_alpha=args.rank,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
