Skip to content

Commit bdd1611

Browse files
[Schedulers] Fix callback steps (huggingface#5261)
* fix all * make fix copies * make fix copies
1 parent c8b0f0e commit bdd1611

File tree

87 files changed

+187
-91
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+187
-91
lines changed

examples/community/composable_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,8 @@ def __call__(
562562
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
563563
progress_bar.update()
564564
if callback is not None and i % callback_steps == 0:
565-
callback(i, t, latents)
565+
step_idx = i // getattr(self.scheduler, "order", 1)
566+
callback(step_idx, t, latents)
566567

567568
# 8. Post-processing
568569
image = self.decode_latents(latents)

examples/community/img2img_inpainting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,8 @@ def __call__(
434434

435435
# call the callback, if provided
436436
if callback is not None and i % callback_steps == 0:
437-
callback(i, t, latents)
437+
step_idx = i // getattr(self.scheduler, "order", 1)
438+
callback(step_idx, t, latents)
438439

439440
latents = 1 / 0.18215 * latents
440441
image = self.vae.decode(latents).sample

examples/community/interpolate_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ def __call__(
372372

373373
# call the callback, if provided
374374
if callback is not None and i % callback_steps == 0:
375-
callback(i, t, latents)
375+
step_idx = i // getattr(self.scheduler, "order", 1)
376+
callback(step_idx, t, latents)
376377

377378
latents = 1 / 0.18215 * latents
378379
image = self.vae.decode(latents).sample

examples/community/lpw_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,8 @@ def __call__(
10881088
progress_bar.update()
10891089
if i % callback_steps == 0:
10901090
if callback is not None:
1091-
callback(i, t, latents)
1091+
step_idx = i // getattr(self.scheduler, "order", 1)
1092+
callback(step_idx, t, latents)
10921093
if is_cancelled_callback is not None and is_cancelled_callback():
10931094
return None
10941095

examples/community/lpw_stable_diffusion_onnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,8 @@ def __call__(
846846
# call the callback, if provided
847847
if i % callback_steps == 0:
848848
if callback is not None:
849-
callback(i, t, latents)
849+
step_idx = i // getattr(self.scheduler, "order", 1)
850+
callback(step_idx, t, latents)
850851
if is_cancelled_callback is not None and is_cancelled_callback():
851852
return None
852853

examples/community/lpw_stable_diffusion_xl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,8 @@ def __call__(
11821182
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
11831183
progress_bar.update()
11841184
if callback is not None and i % callback_steps == 0:
1185-
callback(i, t, latents)
1185+
step_idx = i // getattr(self.scheduler, "order", 1)
1186+
callback(step_idx, t, latents)
11861187

11871188
if not output_type == "latent":
11881189
# make sure the VAE is in float32 mode, as it overflows in float16

examples/community/masked_stable_diffusion_img2img.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def __call__(
202202
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
203203
progress_bar.update()
204204
if callback is not None and i % callback_steps == 0:
205-
callback(i, t, latents)
205+
step_idx = i // getattr(self.scheduler, "order", 1)
206+
callback(step_idx, t, latents)
206207

207208
if not output_type == "latent":
208209
scaled = latents / self.vae.config.scaling_factor

examples/community/multilingual_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,8 @@ def __call__(
407407

408408
# call the callback, if provided
409409
if callback is not None and i % callback_steps == 0:
410-
callback(i, t, latents)
410+
step_idx = i // getattr(self.scheduler, "order", 1)
411+
callback(step_idx, t, latents)
411412

412413
latents = 1 / 0.18215 * latents
413414
image = self.vae.decode(latents).sample

examples/community/pipeline_prompt2prompt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ def __call__(
254254
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
255255
progress_bar.update()
256256
if callback is not None and i % callback_steps == 0:
257-
callback(i, t, latents)
257+
step_idx = i // getattr(self.scheduler, "order", 1)
258+
callback(step_idx, t, latents)
258259

259260
# 8. Post-processing
260261
if not output_type == "latent":

examples/community/pipeline_zero1to3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,8 @@ def __call__(
865865
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
866866
progress_bar.update()
867867
if callback is not None and i % callback_steps == 0:
868-
callback(i, t, latents)
868+
step_idx = i // getattr(self.scheduler, "order", 1)
869+
callback(step_idx, t, latents)
869870

870871
# 8. Post-processing
871872
has_nsfw_concept = None

0 commit comments

Comments
 (0)