Commit 04ddad4

Add 'rank' parameter to Dreambooth LoRA training script (huggingface#3945)

1 parent 03d829d


examples/dreambooth/train_dreambooth_lora.py

Lines changed: 12 additions & 2 deletions
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
         default=None,
         help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
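
For quick intuition about the new flag, here is a self-contained sketch of just the added argparse option. This is not the full parser from train_dreambooth_lora.py, which defines many more (some required) options omitted here:

import argparse

# Sketch mirroring only the option added by this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--rank",
    type=int,
    default=4,
    help="The dimension of the LoRA update matrices.",
)

print(parser.parse_args([]).rank)               # 4, the default from the diff
print(parser.parse_args(["--rank", "8"]).rank)  # 8, a user override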
@@ -845,7 +851,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
         unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
         )
 
     unet.set_attn_processor(unet_lora_attn_procs)
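
What does rank control? LoRA approximates the weight update to each attention projection with a low-rank product, and rank is the inner dimension of that product. The following is a minimal illustrative sketch, not diffusers' actual LoRAAttnProcessor; the diff above only shows that the processor accepts hidden_size, cross_attention_dim, and rank:

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Low-rank update dW(x) = up(down(x)): down projects to `rank` dims,
    # up projects back to the output width.
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # A: in_features -> rank
        self.up = nn.Linear(rank, out_features, bias=False)   # B: rank -> out_features
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # zero init so training starts from the base model

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))

A larger rank gives the update more expressive capacity at the cost of more trainable parameters.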
@@ -860,7 +868,9 @@ def main(args):
         for name, module in text_encoder.named_modules():
             if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                 text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
+                    hidden_size=module.out_proj.out_features,
+                    cross_attention_dim=None,
+                    rank=args.rank,
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         temp_pipeline = DiffusionPipeline.from_pretrained(
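
The practical effect of --rank is on adapter size: each LoRA pair adds rank * (in_features + out_features) trainable weights per adapted projection, so the adapter grows linearly with rank. A quick back-of-the-envelope check, where the 768-wide projection is a hypothetical example rather than a value taken from the script:

def lora_params(in_features: int, out_features: int, rank: int) -> int:
    # down matrix: rank * in_features; up matrix: out_features * rank; no biases
    return rank * in_features + out_features * rank

for r in (4, 8, 16):
    print(r, lora_params(768, 768, r))
# 4 -> 6144, 8 -> 12288, 16 -> 24576 trainable weights,
# versus 589824 for a full 768x768 weight matrix.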
