|
23 | 23 | from ...loaders import TextualInversionLoaderMixin |
24 | 24 | from ...models import AutoencoderKL, UNet2DConditionModel |
25 | 25 | from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers |
26 | | -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor |
| 26 | +from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor |
27 | 27 | from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
28 | 28 |
|
29 | 29 |
|
@@ -129,10 +129,36 @@ def enable_sequential_cpu_offload(self, gpu_id=0): |
129 | 129 |
|
130 | 130 | device = torch.device(f"cuda:{gpu_id}") |
131 | 131 |
|
132 | | - for cpu_offloaded_model in [self.unet, self.text_encoder]: |
| 132 | + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: |
133 | 133 | if cpu_offloaded_model is not None: |
134 | 134 | cpu_offload(cpu_offloaded_model, device) |
135 | 135 |
|
| 136 | + def enable_model_cpu_offload(self, gpu_id=0): |
| 137 | + r""" |
| 138 | + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared |
| 139 | + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` |
| 140 | +        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with |
| 141 | + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. |
| 142 | + """ |
| 143 | + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): |
| 144 | + from accelerate import cpu_offload_with_hook |
| 145 | + else: |
| 146 | + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") |
| 147 | + |
| 148 | + device = torch.device(f"cuda:{gpu_id}") |
| 149 | + |
| 150 | + if self.device.type != "cpu": |
| 151 | + self.to("cpu", silence_dtype_warnings=True) |
| 152 | + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) |
| 153 | + |
| 154 | + hook = None |
| 155 | + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: |
| 156 | + if cpu_offloaded_model is not None: |
| 157 | + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) |
| 158 | + |
| 159 | + # We'll offload the last model manually. |
| 160 | + self.final_offload_hook = hook |
| 161 | + |
136 | 162 | @property |
137 | 163 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device |
138 | 164 | def _execution_device(self): |
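
The offload loop added in `enable_model_cpu_offload` relies on accelerate's hook chaining: each call to `cpu_offload_with_hook` returns a hook, and passing it as `prev_module_hook` to the next call makes the previous model move back to CPU as soon as the next model's `forward` runs. Below is a minimal standalone sketch of that pattern (toy `nn.Linear` modules rather than the pipeline's actual `unet`/`text_encoder`/`vae`; assumes a CUDA device and `accelerate>=0.17.0`):

```python
import torch
from torch import nn
from accelerate import cpu_offload_with_hook  # requires accelerate >= 0.17.0

# Toy stand-ins for the pipeline's unet / text_encoder / vae (illustrative only).
model_a, model_b, model_c = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)

device = torch.device("cuda:0")

hook = None
for module in (model_a, model_b, model_c):
    # Each hook offloads the *previous* module back to CPU when this module's
    # forward is first called, then moves this module to `device`.
    _, hook = cpu_offload_with_hook(module, device, prev_module_hook=hook)

x = torch.randn(1, 8)
out = model_c(model_b(model_a(x)))  # modules hop onto the GPU one at a time

# The last module stays on the GPU until its hook is triggered manually,
# mirroring `final_offload_hook.offload()` at the end of `__call__`.
hook.offload()
```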
@@ -647,6 +673,10 @@ def __call__( |
647 | 673 | self.vae.to(dtype=torch.float32) |
648 | 674 | image = self.decode_latents(latents.float()) |
649 | 675 |
|
| 676 | + # Offload last model to CPU |
| 677 | + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
| 678 | + self.final_offload_hook.offload() |
| 679 | + |
650 | 680 | # 11. Convert to PIL |
651 | 681 | if output_type == "pil": |
652 | 682 | image = self.numpy_to_pil(image) |
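
End to end, a user would turn on the new offloading mode as in the sketch below. The checkpoint id and prompt are hypothetical, and the exact `__call__` arguments depend on the pipeline; any diffusers pipeline that exposes `enable_model_cpu_offload` is used the same way:

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; replace with a real model repository.
pipe = DiffusionPipeline.from_pretrained("some-org/some-model", torch_dtype=torch.float16)

# Whole sub-models (unet, text_encoder, vae) are moved to the GPU only when
# needed, then offloaded once the next sub-model runs. Smaller memory savings
# than enable_sequential_cpu_offload, but much faster inference.
# Requires accelerate >= 0.17.0, as enforced by the version check in the diff.
pipe.enable_model_cpu_offload()

image = pipe("a prompt").images[0]
```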
|