Description
Describe the bug
CPU Offloading Inefficiency: UNet/Transformer Repeatedly Moved Between CPU/GPU During Denoising Loop
Issue
The current CPU offloading mechanism (`enable_model_cpu_offload()`) in Diffusers is inefficient for models with denoising loops: it repeatedly moves the transformer/UNet component between CPU and GPU on every step, adding unnecessary data-transfer overhead that can noticeably slow inference.
Observed Behavior
After enabling CPU offloading with `pipe.enable_model_cpu_offload()`, the transformer/UNet is:
- Loaded to GPU for a single denoising step
- Immediately moved back to CPU after use
- Loaded to GPU again for the next step
- This repeats for every step of the denoising process
This pattern appears in both image models (SD3, FLUX) and video models (WAN, LTX), but the performance impact is more severe for video models due to their larger transformer sizes and 3D tensors.
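Which components carry an offload hook can be confirmed after enabling offloading. A small inspection sketch (it relies on `_hf_hook`, the attribute accelerate's `add_hook_to_module` attaches to wrapped modules):

```python
from accelerate.hooks import CpuOffload

def report_offload_hooks(pipe):
    """Print which pipeline components carry an accelerate offload hook."""
    for name, component in pipe.components.items():
        hook = getattr(component, "_hf_hook", None)  # set by add_hook_to_module
        if hook is not None:
            print(name, type(hook).__name__, isinstance(hook, CpuOffload))
```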
Log Evidence
```
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
[...repeats for each denoising step...]
```
Performance Impact
- Each GPU↔CPU transfer adds significant latency, from milliseconds to seconds depending on model size (see the timing sketch below)
- Some of this cost may be mitigated when the data is still resident in the CUDA cache, but the repeated transfers remain unnecessary
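To put rough numbers on the transfer cost, one can time a raw round trip of a component's weights. A minimal measurement sketch (assumes a CUDA device and a module that can be moved wholesale, i.e. not a quantized model that refuses `.to()`):

```python
import time

import torch

def time_round_trip(module: torch.nn.Module, device: str = "cuda") -> float:
    """Time one CPU -> GPU -> CPU round trip of a module's weights."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    module.to(device)
    torch.cuda.synchronize()
    module.to("cpu")
    torch.cuda.synchronize()
    return time.perf_counter() - start

# e.g. time_round_trip(pipe.transformer): for a multi-GB transformer this is
# typically hundreds of milliseconds or more, paid on every denoising step.
```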
Proposed Solutions
- Add an `exclude_from_offload` parameter to `enable_model_cpu_offload()` (a stopgap approximation is sketched after this list):

  ```python
  pipe.enable_model_cpu_offload(exclude_from_offload=["transformer"])
  ```

- Make the offloading mechanism smarter about components used in loops or called sequentially
- Add a `keep_on_device_during_generation` parameter to control this behavior
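Until something like `exclude_from_offload` exists, the first option can be roughly approximated with accelerate's public `cpu_offload_with_hook` API: offload every component except the transformer. A sketch under those assumptions (component names vary per pipeline, and this skips some of the bookkeeping `enable_model_cpu_offload()` performs, so treat it as untested):

```python
from accelerate import cpu_offload_with_hook

def offload_all_but_transformer(pipe, device: str = "cuda"):
    """Approximate exclude_from_offload=["transformer"]: offload the smaller
    components while keeping the transformer resident on the GPU."""
    hook = None
    for name in ("text_encoder", "text_encoder_2", "vae"):
        component = getattr(pipe, name, None)
        if component is not None:
            # Chain the hooks so each component offloads its predecessor.
            _, hook = cpu_offload_with_hook(component, device, prev_module_hook=hook)
    pipe.transformer.to(device)
```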
The issue is evident in the core loop of diffusion pipelines (e.g., in `WanImageToVideoPipeline.__call__`), where the transformer is called once per timestep but is moved back to CPU after each use.
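Schematically, the hot path looks like this (simplified, illustrative code, not the actual pipeline source):

```python
def denoising_loop(transformer, scheduler, timesteps, latents, prompt_embeds):
    # With model CPU offload enabled, the accelerate hook wraps every
    # transformer call, so the device round trip happens once per timestep.
    for t in timesteps:
        # hook pre_forward: the transformer is moved to the execution device
        noise_pred = transformer(hidden_states=latents, timestep=t,
                                 encoder_hidden_states=prompt_embeds)
        # after the call, the hook makes the transformer eligible to be
        # offloaded back to CPU before the next component runs
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents
```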
Reproduction
Monkey-patch `accelerate.hooks.CpuOffload` to log device moves and you will see this pattern in most model pipelines, or at least the common ones I've tried:
```python
import accelerate.hooks

# Save the original hook methods
original_pre_forward = accelerate.hooks.CpuOffload.pre_forward
original_post_forward = accelerate.hooks.CpuOffload.post_forward

# Create logging wrappers
def logged_pre_forward(self, module, *args, **kwargs):
    module_name = module.__class__.__name__
    result = original_pre_forward(self, module, *args, **kwargs)
    print(f"✅ Moved {module_name} to {self.execution_device}")
    return result

def logged_post_forward(self, module, output):
    module_name = module.__class__.__name__
    result = original_post_forward(self, module, output)
    print(f"🔄 Moved {module_name} back to CPU")
    return result

# Apply the monkey patch
accelerate.hooks.CpuOffload.pre_forward = logged_pre_forward
accelerate.hooks.CpuOffload.post_forward = logged_post_forward
```
Then enable offloading on any diffusers pipeline that supports it:

```python
pipe.enable_model_cpu_offload()
```
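Putting it together, a minimal end-to-end sketch (assumes the FLUX.1-dev weights are available; any offload-capable pipeline reproduces the pattern):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # installs the (now logging) CpuOffload hooks

# Every denoising step logs a to-GPU / back-to-CPU pair for the transformer.
image = pipe("a photo of a cat", num_inference_steps=10).images[0]
```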
Logs
```
------------------------------------------------ live log call ------------------------------------------------
INFO common.logger:utils.py:85 Image loaded from Base64 bytes, size: (1897, 788)
INFO common.logger:utils.py:74 Image Resized from: (1897, 788) to 1897x788
WARNING common.logger:pipeline_helpers.py:100 Evicting all 2 models from cache
INFO common.logger:pipeline_helpers.py:228 Loading quantized model from /WORKSPACE/quantized/black-forest-labs/FLUX.1-dev/8bit/transformer
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.48it/s]
INFO common.logger:utils.py:151 Calling get_quantized_model with args: (), kwargs: {'model_id': 'black-forest-labs/FLUX.1-dev', 'subfolder': 'transformer', 'model_class': <class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'>, 'target_precision': 8, 'torch_dtype': torch.bfloat16}, took: 1.44s
INFO common.logger:pipeline_helpers.py:228 Loading quantized model from /WORKSPACE/quantized/black-forest-labs/FLUX.1-schnell/8bit/text_encoder_2
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 12.12it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 12.03it/s]
WARNING common.logger:pipeline_helpers.py:77 Cache miss for ('get_pipeline', PipelineConfig(model_id='black-forest-labs/FLUX.1-dev', model_family='flux', ip_adapter_models=(), ip_adapter_subfolders=(), ip_adapter_weights=(), ip_adapter_image_encoder_model='', ip_adapter_image_encoder_subfolder='')) - took: 2.93s - Cache size: 1/2
INFO common.logger:flux.py:128 Image to image call {'width': 1897, 'height': 788, 'prompt': 'Change to night time and add rain and lighting', 'negative_prompt': 'worst quality, inconsistent motion, blurry, jittery, distorted', 'image': <PIL.Image.Image image mode=RGB size=1897x788 at 0x7FB9D0F31650>, 'num_inference_steps': 10, 'generator': <torch._C.Generator object at 0x7fbc5ead1410>, 'strength': 0.5, 'guidance_scale': 3.5}
`height` and `width` have to be divisible by 16 but are 788 and 1897. Dimensions will be resized accordingly
✅ Moved CLIPTextModel to cuda:0
🔄 Moved CLIPTextModel back to CPU
✅ Moved T5EncoderModel to cuda:0
🔄 Moved T5EncoderModel back to CPU
✅ Moved AutoencoderKL to cuda:0
  0%|          | 0/5 [00:00<?, ?it/s]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 20%|██        | 1/5 [00:23<01:32, 23.13s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 40%|████      | 2/5 [00:25<00:32, 10.99s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 60%|██████    | 3/5 [00:28<00:14, 7.11s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 80%|████████  | 4/5 [00:30<00:05, 5.30s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
100%|██████████| 5/5 [00:33<00:00, 6.63s/it]
✅ Moved AutoencoderKL to cuda:0
WARNING common.logger:memory.py:58 GPU Memory Usage: 4.91GB / 25.77GB, Reserved: 4.29GB, Allocated: 0.51GB, Usage: 19.05%
PASSED
tests/images/test_image_to_image.py::test_image_to_image[flux-kontext-1-image_to_image] Base64: /9j/4AAQSkZJRgABAQEAkACQAAD/4QLcRXhpZgAATU0AKgAAAAgABAE7AAIAAAAGAAABSodpAAQAAAABAAABUJydAAEAAAAMAAAC.
```
System Info
- 🤗 Diffusers version: 0.32.2
- Platform: Windows-10-10.0.19045-SP0
- Running on Google Colab?: No
- Python version: 3.12.7
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.1
- Transformers version: 4.51.0
- Accelerate version: 1.6.0
- PEFT version: not installed
- Bitsandbytes version: 0.45.5
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
- Using GPU in script?: Docker mostly
- Using distributed or parallel set-up in script?: No
Who can help?
No response