
CPU Offloading Inefficiency: UNet/Transformer Repeatedly Moved Between CPU/GPU During Loops #11872

Open
@JoeGaffney

Description


Describe the bug

CPU Offloading Inefficiency: UNet/Transformer Repeatedly Moved Between CPU/GPU During Denoising Loop

Issue

The current CPU offloading mechanism (enable_model_cpu_offload()) in Diffusers is inefficient for pipelines with denoising loops. It moves the transformer/UNet between CPU and GPU on every step, adding unnecessary back-and-forth data transfer that may slow down inference.

Observed Behavior

After enabling CPU offloading with pipe.enable_model_cpu_offload(), the transformer/UNet is:

  1. Loaded to GPU for a single denoising step
  2. Immediately moved back to CPU after use
  3. Loaded to GPU again for the next step
  4. The cycle repeats for every step in the denoising process

This pattern appears in both image models (SD3, FLUX) and video models (WAN, LTX), but the performance impact is more severe for video models due to their larger transformer sizes and 3D tensors.

Log Evidence

✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
[...repeats for each denoising step...]

Performance Impact

  • Each GPU↔CPU transfer adds significant latency, from milliseconds to seconds depending on model size
  • It is unclear how much of this is mitigated, for example if the weights are still held in the CUDA cache from the previous step
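For a rough sense of scale, the copy time is bytes moved divided by link bandwidth. The sketch below multiplies that out for both policies. All figures are assumptions, not measurements: about 12B parameters for the FLUX.1 transformer, 1 byte per parameter (the logs load an 8-bit checkpoint), roughly 20 GB/s effective bandwidth in both directions (a guess for PCIe 4.0 x16), and 5 steps (num_inference_steps=10 with strength=0.5, as in the logs below).

```python
def transfer_seconds(num_params: float, bytes_per_param: float,
                     bandwidth_gb_per_s: float) -> float:
    """Seconds to copy a model's weights once over a link of the given bandwidth."""
    return num_params * bytes_per_param / (bandwidth_gb_per_s * 1e9)


def loop_transfer_overhead(num_steps: int, num_params: float, bytes_per_param: float,
                           bandwidth_gb_per_s: float, offload_every_step: bool = True) -> float:
    """Total copy time for one denoising loop under two offloading policies.

    Assumes the same bandwidth for CPU->GPU and GPU->CPU copies.
    """
    one_way = transfer_seconds(num_params, bytes_per_param, bandwidth_gb_per_s)
    # Offload every step: one CPU->GPU and one GPU->CPU copy per step.
    # Keep resident: load once before the loop and offload once after it.
    copies = 2 * num_steps if offload_every_step else 2
    return copies * one_way


# Assumed, not measured: ~12B parameters, 1 byte/parameter (8-bit), ~20 GB/s, 5 steps.
per_step = loop_transfer_overhead(5, 12e9, 1, 20)
resident = loop_transfer_overhead(5, 12e9, 1, 20, offload_every_step=False)
print(f"offload every step: {per_step:.1f} s, keep resident: {resident:.1f} s")
```

Under these assumptions it prints "offload every step: 6.0 s, keep resident: 1.2 s". Observed step times also include compute, so this only estimates the transfer share.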

Proposed Solutions

  1. Add an exclude_from_offload parameter to enable_model_cpu_offload():
     pipe.enable_model_cpu_offload(exclude_from_offload=["transformer"])
  2. Make the offloading mechanism smarter about components that are used in loops or sequentially
  3. Add a keep_on_device_during_generation parameter to control this behavior
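One possible reading of the exclude_from_offload proposal, sketched in plain Python: excluded components are moved to the GPU once and never get an offload hook, while the rest keep the usual sequential offloading. The helper plan_offload and these semantics are hypothetical. The "a->b->c" input format is modeled on the model_cpu_offload_seq strings that Diffusers pipelines declare; the exact sequence for a given pipeline may differ.

```python
def plan_offload(offload_seq: str, exclude_from_offload=()):
    """Split an offload sequence into (offloaded, resident) component names.

    Hypothetical semantics for the proposed `exclude_from_offload` parameter:
    offloaded components keep their order in the chain; resident ones would
    be placed on the GPU once and left there.
    """
    order = [name.strip() for name in offload_seq.split("->") if name.strip()]
    excluded = set(exclude_from_offload)
    unknown = excluded - set(order)
    if unknown:
        # Fail loudly on typos instead of silently offloading everything.
        raise ValueError(f"Unknown components: {sorted(unknown)}")
    offloaded = [name for name in order if name not in excluded]
    resident = [name for name in order if name in excluded]
    return offloaded, resident


offloaded, resident = plan_offload("text_encoder->text_encoder_2->transformer->vae",
                                   exclude_from_offload=["transformer"])
print(offloaded)  # ['text_encoder', 'text_encoder_2', 'vae']
print(resident)   # ['transformer']
```

The trade-off is VRAM: a resident transformer occupies GPU memory for the whole call, which is exactly what offloading tries to avoid, so this only helps when the denoiser fits alongside the largest offloaded component.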

The issue is evident in the core loop of diffusion pipelines (e.g. WanImageToVideoPipeline.__call__), where the transformer is called once per timestep but is moved back to CPU after each call.
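The keep_on_device_during_generation idea amounts to loading the denoiser once before the loop and offloading it once afterwards. Below is a hypothetical context manager, keep_on_device, demonstrated on a stand-in FakeModule that only records device moves, so no PyTorch is needed. Like torch.Tensor.to, FakeModule.to skips the copy when already on the target device, so a per-step .to() call inside the loop costs nothing.

```python
from contextlib import contextmanager


class FakeModule:
    """Stand-in for a torch.nn.Module that only records device moves."""

    def __init__(self, name: str):
        self.name = name
        self.device = "cpu"
        self.moves = []  # list of (from_device, to_device) copies

    def to(self, device: str) -> "FakeModule":
        # Mirror torch.Tensor.to: no copy when already on the target device.
        if device != self.device:
            self.moves.append((self.device, device))
            self.device = device
        return self


@contextmanager
def keep_on_device(module, device="cuda:0", offload_device="cpu"):
    """Hypothetical: keep `module` on `device` for the whole block, offload once at the end."""
    module.to(device)
    try:
        yield module
    finally:
        module.to(offload_device)


transformer = FakeModule("transformer")
with keep_on_device(transformer, "cuda:0"):
    for _ in range(5):  # 5 denoising steps
        # A per-call pre-forward hook may still call .to(device); here it is a no-op.
        transformer.to("cuda:0")
        # noise_pred = transformer(latents, timestep) would run here.
print(transformer.moves)  # [('cpu', 'cuda:0'), ('cuda:0', 'cpu')]
```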

Reproduction

Monkey-patch accelerate.hooks.CpuOffload to log its hook calls; the pattern then shows up in most model pipelines, or at least in the common ones I've tried.

import accelerate.hooks

# Save the original methods
original_pre_forward = accelerate.hooks.CpuOffload.pre_forward
original_post_forward = accelerate.hooks.CpuOffload.post_forward


# Create logging wrappers
def logged_pre_forward(self, module, *args, **kwargs):
    module_name = module.__class__.__name__
    result = original_pre_forward(self, module, *args, **kwargs)
    print(f"✅ Moved {module_name} to {self.execution_device}")
    return result


def logged_post_forward(self, module, output):
    module_name = module.__class__.__name__
    result = original_post_forward(self, module, output)
    print(f"🔄 Moved {module_name} back to CPU")
    return result


# Apply the monkey patch
accelerate.hooks.CpuOffload.pre_forward = logged_pre_forward
accelerate.hooks.CpuOffload.post_forward = logged_post_forward

Then enable offloading on any Diffusers pipeline that supports it:

pipe.enable_model_cpu_offload()
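A variant of the logging patch above that restores the class afterwards and avoids duplicating the wrapper, offered as a sketch: log_method_calls is a hypothetical helper built on setattr and functools.wraps. It also handles methods a class inherits rather than defines, by deleting the override on exit instead of pinning the inherited function onto the subclass. The demo uses stand-in hook classes so it runs without Accelerate.

```python
import functools
from contextlib import contextmanager

_MISSING = object()


@contextmanager
def log_method_calls(cls, method_name, message):
    """Temporarily wrap cls.<method_name> so every call prints `message`.

    `message` is a str.format template; {module} is the class name of the
    module passed to the hook and {hook} is the hook instance itself.
    """
    own_attr = vars(cls).get(method_name, _MISSING)  # defined on cls itself?
    original = getattr(cls, method_name)  # resolved through the MRO

    @functools.wraps(original)
    def wrapper(self, module, *args, **kwargs):
        result = original(self, module, *args, **kwargs)
        print(message.format(module=type(module).__name__, hook=self))
        return result

    setattr(cls, method_name, wrapper)
    try:
        yield
    finally:
        if own_attr is _MISSING:
            delattr(cls, method_name)  # fall back to the inherited method again
        else:
            setattr(cls, method_name, own_attr)


# Stand-ins for a hook hierarchy, so the demo runs without Accelerate.
class BaseHook:
    def post_forward(self, module, output):
        return output


class DummyHook(BaseHook):
    execution_device = "cuda:0"

    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs


with log_method_calls(DummyHook, "post_forward", "🔄 {module} post_forward on {hook.execution_device}"):
    DummyHook().post_forward(object(), "output")  # prints "🔄 object post_forward on cuda:0"
```

With Accelerate installed, the same helper could wrap accelerate.hooks.CpuOffload around a single pipeline call, e.g. with log_method_calls(accelerate.hooks.CpuOffload, "pre_forward", "✅ Moved {module} to {hook.execution_device}"): pipe(...).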

Logs

------------------------------------------ live log call ------------------------------------------
INFO     common.logger:utils.py:85 Image loaded from Base64 bytes, size: (1897, 788)
INFO     common.logger:utils.py:74 Image Resized from: (1897, 788) to 1897x788
WARNING  common.logger:pipeline_helpers.py:100 Evicting all 2 models from cache
INFO     common.logger:pipeline_helpers.py:228 Loading quantized model from /WORKSPACE/quantized/black-forest-labs/FLUX.1-dev/8bit/transformer
Loading checkpoint shards: 100%|████████████████████| 2/2 [00:01<00:00,  1.48it/s]
INFO     common.logger:utils.py:151 Calling get_quantized_model with args: (), kwargs: {'model_id': 'black-forest-labs/FLUX.1-dev', 'subfolder': 'transformer', 'model_class': <class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'>, 'target_precision': 8, 'torch_dtype': torch.bfloat16}, took: 1.44s
INFO     common.logger:pipeline_helpers.py:228 Loading quantized model from /WORKSPACE/quantized/black-forest-labs/FLUX.1-schnell/8bit/text_encoder_2
Loading checkpoint shards: 100%|████████████████████| 2/2 [00:00<00:00, 12.12it/s]
Loading pipeline components...: 100%|████████████████████| 7/7 [00:00<00:00, 12.03it/s]
WARNING  common.logger:pipeline_helpers.py:77 Cache miss for ('get_pipeline', PipelineConfig(model_id='black-forest-labs/FLUX.1-dev', model_family='flux', ip_adapter_models=(), ip_adapter_subfolders=(), ip_adapter_weights=(), ip_adapter_image_encoder_model='', ip_adapter_image_encoder_subfolder='')) - took: 2.93s - Cache size: 1/2
INFO     common.logger:flux.py:128 Image to image call {'width': 1897, 'height': 788, 'prompt': 'Change to night time and add rain and lighting', 'negative_prompt': 'worst quality, inconsistent motion, blurry, jittery, distorted', 'image': <PIL.Image.Image image mode=RGB size=1897x788 at 0x7FB9D0F31650>, 'num_inference_steps': 10, 'generator': <torch._C.Generator object at 0x7fbc5ead1410>, 'strength': 0.5, 'guidance_scale': 3.5}
`height` and `width` have to be divisible by 16 but are 788 and 1897. Dimensions will be resized accordingly
✅ Moved CLIPTextModel to cuda:0
🔄 Moved CLIPTextModel back to CPU
✅ Moved T5EncoderModel to cuda:0
🔄 Moved T5EncoderModel back to CPU
✅ Moved AutoencoderKL to cuda:0
  0%|                    | 0/5 [00:00<?, ?it/s]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 20%|████                | 1/5 [00:23<01:32, 23.13s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 40%|████████            | 2/5 [00:25<00:32, 10.99s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 60%|████████████        | 3/5 [00:28<00:14,  7.11s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
 80%|████████████████    | 4/5 [00:30<00:05,  5.30s/it]
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
100%|████████████████████| 5/5 [00:33<00:00,  6.63s/it]
✅ Moved AutoencoderKL to cuda:0
WARNING  common.logger:memory.py:58 GPU Memory Usage: 4.91GB / 25.77GB,  Reserved: 4.29GB, Allocated: 0.51GB, Usage: 19.05%
PASSED
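To summarize a long log like the one above, a small parser can tally logged moves per module class. This sketch assumes the exact message formats printed by the logging patch in the reproduction; the counts reflect calls to the patched hook methods, not independently measured copies.

```python
import re
from collections import Counter

TO_DEVICE = re.compile(r"✅ Moved (\w+) to \S+")
BACK_TO_CPU = re.compile(r"🔄 Moved (\w+) back to CPU")


def tally_moves(log_text: str) -> dict:
    """Count logged moves per module class: {name: (to_device, back_to_cpu)}."""
    to_device, to_cpu = Counter(), Counter()
    for line in log_text.splitlines():
        if (m := TO_DEVICE.search(line)):
            to_device[m.group(1)] += 1
        elif (m := BACK_TO_CPU.search(line)):
            to_cpu[m.group(1)] += 1
    names = sorted(to_device.keys() | to_cpu.keys())
    return {name: (to_device[name], to_cpu[name]) for name in names}


sample = """✅ Moved AutoencoderKL to cuda:0
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
✅ Moved FluxTransformer2DModel to cuda:0
🔄 Moved FluxTransformer2DModel back to CPU
✅ Moved AutoencoderKL to cuda:0"""
print(tally_moves(sample))
# {'AutoencoderKL': (2, 0), 'FluxTransformer2DModel': (2, 2)}
```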

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Windows-10-10.0.19045-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.7
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.30.1
  • Transformers version: 4.51.0
  • Accelerate version: 1.6.0
  • PEFT version: not installed
  • Bitsandbytes version: 0.45.5
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
  • Using GPU in script?: Docker mostly
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Labels: bug (Something isn't working)