Skip to content

Commit 9f8c915

Browse files
authored
[Dreambooth] flax fixes (huggingface#1765)
* Fail if there are less images than the effective batch size. * Remove lr-scheduler arg as it's currently ignored. * Make guidance_scale work for batch_size > 1.
1 parent 8331da4 commit 9f8c915

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,6 @@ def parse_args():
142142
default=False,
143143
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
144144
)
145-
parser.add_argument(
146-
"--lr_scheduler",
147-
type=str,
148-
default="constant",
149-
help=(
150-
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
151-
' "constant", "constant_with_warmup"]'
152-
),
153-
)
154145
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
155146
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
156147
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -429,6 +420,13 @@ def collate_fn(examples):
429420
return batch
430421

431422
total_train_batch_size = args.train_batch_size * jax.local_device_count()
423+
if len(train_dataset) < total_train_batch_size:
424+
raise ValueError(
425+
f"Training batch size is {total_train_batch_size}, but your dataset only contains"
426+
f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
427+
f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
428+
)
429+
432430
train_dataloader = torch.utils.data.DataLoader(
433431
train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
434432
)

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def __call__(
337337
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
338338
if len(prompt_ids.shape) > 2:
339339
# Assume sharded
340-
guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
340+
guidance_scale = guidance_scale[:, None]
341341

342342
if jit:
343343
images = _p_generate(

0 commit comments

Comments
 (0)