Skip to content

Commit 8ca179a

Browse files
authored
Update free model hooks (huggingface#5680)
update free model hooks
1 parent 71f56c7 commit 8ca179a

16 files changed

+38
-16
lines changed

src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,8 +1109,6 @@ def __call__(
11091109
nsfw_detected = None
11101110
watermark_detected = None
11111111

1112-
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
1113-
self.unet_offload_hook.offload()
11141112
else:
11151113
# 10. Post-processing
11161114
image = (image / 2 + 0.5).clamp(0, 1)
@@ -1119,9 +1117,7 @@ def __call__(
11191117
# 11. Run safety checker
11201118
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
11211119

1122-
# Offload last model to CPU
1123-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1124-
self.final_offload_hook.offload()
1120+
self.maybe_free_model_hooks()
11251121

11261122
if not return_dict:
11271123
return (image, nsfw_detected, watermark_detected)

src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,8 @@ def __call__(
388388
# post-processing
389389
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
390390

391+
self.maybe_free_model_hooks()
392+
391393
if output_type not in ["pt", "np", "pil"]:
392394
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
393395

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,9 @@ def __call__(
321321
callback_steps=callback_steps,
322322
return_dict=return_dict,
323323
)
324+
325+
self.maybe_free_model_hooks()
326+
324327
return outputs
325328

326329

@@ -558,6 +561,9 @@ def __call__(
558561
callback_steps=callback_steps,
559562
return_dict=return_dict,
560563
)
564+
565+
self.maybe_free_model_hooks()
566+
561567
return outputs
562568

563569

@@ -593,7 +599,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
593599
"""
594600

595601
_load_connected_pipes = True
596-
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
602+
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
597603

598604
def __init__(
599605
self,
@@ -802,4 +808,7 @@ def __call__(
802808
callback_steps=callback_steps,
803809
return_dict=return_dict,
804810
)
811+
812+
self.maybe_free_model_hooks()
813+
805814
return outputs

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ def __call__(
481481
# 7. post-processing
482482
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
483483

484+
self.maybe_free_model_hooks()
485+
484486
if output_type not in ["pt", "np", "pil"]:
485487
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
486488

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ def __call__(
616616
# post-processing
617617
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
618618

619+
self.maybe_free_model_hooks()
620+
619621
if output_type not in ["pt", "np", "pil"]:
620622
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
621623

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def __call__(
527527
if negative_prompt is None:
528528
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
529529

530-
self.maybe_free_model_hooks
530+
self.maybe_free_model_hooks()
531531
else:
532532
image_embeddings, zero_embeds = image_embeddings.chunk(2)
533533

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ def __call__(
326326
callback_on_step_end=callback_on_step_end,
327327
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
328328
)
329+
self.maybe_free_model_hooks()
330+
329331
return outputs
330332

331333

@@ -572,6 +574,8 @@ def __call__(
572574
callback_on_step_end=callback_on_step_end,
573575
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
574576
)
577+
578+
self.maybe_free_model_hooks()
575579
return outputs
576580

577581

@@ -842,4 +846,6 @@ def __call__(
842846
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
843847
**kwargs,
844848
)
849+
self.maybe_free_model_hooks()
850+
845851
return outputs

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,14 +531,10 @@ def __call__(
531531
# if negative prompt has been defined, we retrieve split the image embedding into two
532532
if negative_prompt is None:
533533
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
534-
535-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
536-
self.final_offload_hook.offload()
537534
else:
538535
image_embeddings, zero_embeds = image_embeddings.chunk(2)
539536

540-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
541-
self.prior_hook.offload()
537+
self.maybe_free_model_hooks()
542538

543539
if output_type not in ["pt", "np"]:
544540
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,10 @@ def __call__(
545545
# if negative prompt has been defined, we retrieve split the image embedding into two
546546
if negative_prompt is None:
547547
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
548-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
549-
self.final_offload_hook.offload()
550548
else:
551549
image_embeddings, zero_embeds = image_embeddings.chunk(2)
552-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
553-
self.prior_hook.offload()
550+
551+
self.maybe_free_model_hooks()
554552

555553
if output_type not in ["pt", "np"]:
556554
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,7 @@ def __call__(
918918
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
919919

920920
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
921+
self.maybe_free_model_hooks()
921922

922923
if not return_dict:
923924
return (image, has_nsfw_concept)

0 commit comments

Comments
 (0)