Skip to content

Commit 4623f09

Browse files
authored
[DreamBooth] Set train mode for text encoder (huggingface#1012)
Set train mode for text encoder
1 parent abe0582 commit 4623f09

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ def collate_fn(examples):
574574

575575
for epoch in range(args.num_train_epochs):
576576
unet.train()
577+
if args.train_text_encoder:
578+
text_encoder.train()
577579
for step, batch in enumerate(train_dataloader):
578580
with accelerator.accumulate(unet):
579581
# Convert images to latent space

0 commit comments

Comments
 (0)