Skip to content

Commit 07eac4d

Browse files
authored
Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (huggingface#5893)
* Fix bug related to parsing `unet_time_cond_proj_dim`.
* Fix the analogous bug in the SD-XL LCM distillation script.
1 parent c079cae commit 07eac4d

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

examples/consistency_distillation/train_lcm_distill_sd_wds.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,15 @@ def parse_args():
657657
default=0.001,
658658
help="The huber loss parameter. Only used if `--loss_type=huber`.",
659659
)
660+
parser.add_argument(
661+
"--unet_time_cond_proj_dim",
662+
type=int,
663+
default=256,
664+
help=(
665+
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
666+
" does not have `time_cond_proj_dim` set."
667+
),
668+
)
660669
# ----Exponential Moving Average (EMA)----
661670
parser.add_argument(
662671
"--ema_decay",
@@ -1138,7 +1147,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
11381147

11391148
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
11401149
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
1141-
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
1150+
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
11421151
w = w.reshape(bsz, 1, 1, 1)
11431152
# Move to U-Net device and dtype
11441153
w = w.to(device=latents.device, dtype=latents.dtype)

examples/consistency_distillation/train_lcm_distill_sdxl_wds.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,15 @@ def parse_args():
677677
default=0.001,
678678
help="The huber loss parameter. Only used if `--loss_type=huber`.",
679679
)
680+
parser.add_argument(
681+
"--unet_time_cond_proj_dim",
682+
type=int,
683+
default=256,
684+
help=(
685+
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
686+
" does not have `time_cond_proj_dim` set."
687+
),
688+
)
680689
# ----Exponential Moving Average (EMA)----
681690
parser.add_argument(
682691
"--ema_decay",
@@ -1233,6 +1242,7 @@ def compute_embeddings(
12331242

12341243
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
12351244
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
1245+
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
12361246
w = w.reshape(bsz, 1, 1, 1)
12371247
w = w.to(device=latents.device, dtype=latents.dtype)
12381248

@@ -1243,7 +1253,7 @@ def compute_embeddings(
12431253
noise_pred = unet(
12441254
noisy_model_input,
12451255
start_timesteps,
1246-
timestep_cond=None,
1256+
timestep_cond=w_embedding,
12471257
encoder_hidden_states=prompt_embeds.float(),
12481258
added_cond_kwargs=encoded_text,
12491259
).sample
@@ -1308,7 +1318,7 @@ def compute_embeddings(
13081318
target_noise_pred = target_unet(
13091319
x_prev.float(),
13101320
timesteps,
1311-
timestep_cond=None,
1321+
timestep_cond=w_embedding,
13121322
encoder_hidden_states=prompt_embeds.float(),
13131323
added_cond_kwargs=encoded_text,
13141324
).sample

0 commit comments

Comments (0)