
Commit cdf2ae8

[Enhance] Add LoRA rank args in train_text_to_image_lora (huggingface#3866)
* add rank args in lora finetune
* del network_alpha
1 parent 49949f3

1 file changed: examples/text_to_image/train_text_to_image_lora.py (11 additions, 1 deletion)
```diff
@@ -343,6 +343,12 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -464,7 +470,11 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]

-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
+        )

     unet.set_attn_processor(lora_attn_procs)

```