
Commit 9357965

Refactor model offload (huggingface#4514)
* [Draft] Refactor model offload
* [Draft] Refactor model offload
* Apply suggestions from code review
* cpu offload updates
* remove model cpu offload from individual pipelines
* add hook to offload models to cpu
* clean up
* model offload
* add model cpu offload string
* make style
* clean up
* fixes for offload issues
* fix test issues
* resolve merge conflicts
* update src/diffusers/pipelines/pipeline_utils.py

  Co-authored-by: Patrick von Platen <[email protected]>

* make style
* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent 16a056a commit 9357965

85 files changed, +370 −1822 lines. Large commits hide some content by default; only a subset of the changed files is shown below.
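The pattern across the touched pipelines is the same: the per-pipeline `enable_model_cpu_offload` implementations are deleted, each pipeline instead declares a `model_cpu_offload_seq` class attribute, and the shared machinery plus the new `maybe_free_model_hooks()` call at the end of `__call__` are centralized in `src/diffusers/pipelines/pipeline_utils.py` (per the commit message). The user-facing API is unchanged; a minimal usage sketch follows, where the checkpoint name is illustrative and not part of this commit:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# The base-class implementation now reads the pipeline's declared
# `model_cpu_offload_seq` (e.g. "text_encoder->unet->vae") instead of each
# pipeline hard-coding its own accelerate hook chain.
pipe.enable_model_cpu_offload()

image = pipe("an astronaut riding a horse").images[0]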

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 3 additions & 33 deletions
@@ -19,8 +19,6 @@
 from packaging import version
 from transformers import CLIPImageProcessor, XLMRobertaTokenizer

-from diffusers.utils import is_accelerate_available, is_accelerate_version
-
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
@@ -100,6 +98,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]

     def __init__(
@@ -221,34 +220,6 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
-        time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
-        Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
-        iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     def _encode_prompt(
         self,
         prompt,
@@ -750,9 +721,8 @@ def __call__(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
+        # Offload all models
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image, has_nsfw_concept)
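The method deleted above is representative of the duplicated logic this commit removes. Below is a minimal sketch of how the centralized version in `pipeline_utils.py` might turn the declarative `model_cpu_offload_seq` string into the same accelerate hook chain; the hook-list attribute name (`_all_hooks`) and the skipping of non-module components are assumptions, not taken from this diff:

import torch
from accelerate import cpu_offload_with_hook

def enable_model_cpu_offload(self, gpu_id=0):
    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # free GPU memory before the hooks are attached

    hook = None
    self._all_hooks = []  # assumed bookkeeping so maybe_free_model_hooks() can find the hooks later
    # e.g. "text_encoder->unet->vae": the order in which models are moved to the GPU on demand
    for name in self.model_cpu_offload_seq.split("->"):
        model = getattr(self, name, None)
        if not isinstance(model, torch.nn.Module):
            continue  # skip missing or optional components
        _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
        self._all_hooks.append(hook)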

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 3 additions & 33 deletions
@@ -21,8 +21,6 @@
 from packaging import version
 from transformers import CLIPImageProcessor, XLMRobertaTokenizer

-from diffusers.utils import is_accelerate_available, is_accelerate_version
-
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
@@ -127,6 +125,7 @@ class AltDiffusionImg2ImgPipeline(
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]

     def __init__(
@@ -219,34 +218,6 @@ def __init__(
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
-        time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
-        Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
-        iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     def _encode_prompt(
         self,
         prompt,
@@ -773,9 +744,8 @@ def __call__(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
+        # Offload all models
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image, has_nsfw_concept)

src/diffusers/pipelines/audioldm/pipeline_audioldm.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ class AudioLDMPipeline(DiffusionPipeline):
         vocoder ([`~transformers.SpeechT5HifiGan`]):
             Vocoder of class `SpeechT5HifiGan`.
     """
+    model_cpu_offload_seq = "text_encoder->unet->vae"

     def __init__(
         self,

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 2 additions & 0 deletions
@@ -947,6 +947,8 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

+        self.maybe_free_model_hooks()
+
         # 8. Post-processing
         if not output_type == "latent":
             latents = 1 / self.vae.config.scaling_factor * latents
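`maybe_free_model_hooks()` replaces the old pattern of keeping a `final_offload_hook` and calling `.offload()` on it at the end of `__call__`. A plausible sketch follows, assuming the base class records the hooks it created (the `_all_hooks` name is an assumption, as above); the actual implementation in `pipeline_utils.py` may differ:

def maybe_free_model_hooks(self):
    # No-op when model CPU offload was never enabled, so pipelines can call this unconditionally.
    if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
        return

    for hook in self._all_hooks:
        hook.offload()  # move the hook's model back to the CPU and free its GPU memory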

src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

Lines changed: 3 additions & 33 deletions
@@ -5,8 +5,6 @@
 from ...models import UNet2DModel
 from ...schedulers import CMStochasticIterativeScheduler
 from ...utils import (
-    is_accelerate_available,
-    is_accelerate_version,
     logging,
     replace_example_docstring,
 )
@@ -62,6 +60,7 @@ class ConsistencyModelPipeline(DiffusionPipeline):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
             compatible with [`CMStochasticIterativeScheduler`].
     """
+    model_cpu_offload_seq = "unet"

     def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
         super().__init__()
@@ -73,34 +72,6 @@ def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler)

         self.safety_checker = None

-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
-        time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
-        Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
-        iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        hook = None
-        for cpu_offloaded_model in [self.unet]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels, height, width)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -280,9 +251,8 @@ def __call__(
         # 6. Post-process image sample
         image = self.postprocess_image(sample, output_type=output_type)

-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
+        # Offload all models
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image,)
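ConsistencyModelPipeline shows that the declared sequence can name a single component. For a hypothetical custom pipeline (not part of this commit), opting into the shared offload support would now be a one-line declaration:

from diffusers import DiffusionPipeline

class MyConsistencyLikePipeline(DiffusionPipeline):  # made-up example for illustration
    model_cpu_offload_seq = "unet"  # only one model needs to be shuttled to the GPU

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)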

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 3 additions & 33 deletions
@@ -29,8 +29,6 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
-    is_accelerate_available,
-    is_accelerate_version,
     logging,
     replace_example_docstring,
 )
@@ -125,6 +123,7 @@ class StableDiffusionControlNetPipeline(
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]

     def __init__(
@@ -210,34 +209,6 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            # the safety checker can offload the vae again
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # control net hook has be manually offloaded as it alternates with unet
-        cpu_offload_with_hook(self.controlnet, device)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
         self,
@@ -1031,9 +1002,8 @@ def __call__(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
+        # Offload all models
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image, has_nsfw_concept)
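The removed ControlNet variant additionally hooked `self.controlnet` outside the chained sequence, since the controlnet alternates with the unet during denoising; how the centralized code handles the controlnet is not visible in this diff. The user-facing call is unchanged, as in this usage sketch (checkpoint names are illustrative):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

# Offload order comes from the declared `model_cpu_offload_seq` ("text_encoder->unet->vae").
pipe.enable_model_cpu_offload()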

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 3 additions & 33 deletions
@@ -28,8 +28,6 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
-    is_accelerate_available,
-    is_accelerate_version,
     logging,
     replace_example_docstring,
 )
@@ -149,6 +147,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]

     def __init__(
@@ -234,34 +233,6 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            # the safety checker can offload the vae again
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # control net hook has be manually offloaded as it alternates with unet
-        cpu_offload_with_hook(self.controlnet, device)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
         self,
@@ -1107,9 +1078,8 @@ def __call__(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

-        # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
+        # Offload all models
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image, has_nsfw_concept)
