
Commit 9c85611

Add model offload to x4 upscaler (huggingface#3187)
* Add model offload to x4 upscaler
* fix
1 parent 9bce375 commit 9c85611

File tree

1 file changed: +32 −2 lines


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 32 additions & 2 deletions
@@ -23,7 +23,7 @@
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
+from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
@@ -129,10 +129,36 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+            if cpu_offloaded_model is not None:
+                _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -647,6 +673,10 @@ def __call__(
         self.vae.to(dtype=torch.float32)
         image = self.decode_latents(latents.float())
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
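
For context, a minimal usage sketch of the offload path this commit adds (not part of the diff itself): it assumes the stabilityai/stable-diffusion-x4-upscaler checkpoint and an illustrative low_res.png input, and calls enable_model_cpu_offload() in place of the usual pipe.to("cuda").

import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

# Load the x4 upscaler in fp16 (checkpoint id assumed for illustration).
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)

# Instead of pipe.to("cuda"): text_encoder, unet and vae are moved to the GPU
# only when their forward pass runs, and offloaded again when the next model starts.
pipe.enable_model_cpu_offload()

low_res = Image.open("low_res.png").convert("RGB")  # hypothetical low-resolution input
upscaled = pipe(prompt="a white cat", image=low_res).images[0]
upscaled.save("upscaled.png")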

0 commit comments
