AutoencoderDC.encode fails with torch.compile(fullgraph=True) - "name 'torch' is not defined" #11864

Open
@SingleBicycle

Description

Describe the bug

I'm trying to optimize my data preprocessing pipeline for the Sana model by using torch.compile on the DC-AE encoder. Following PyTorch's best practices, I attempted to compile only the encode method with fullgraph=True for better performance, but I'm encountering an error.

When I try:

dae.encode = torch.compile(dae.encode, fullgraph=True)

The code fails with NameError: name 'torch' is not defined when calling dae.encode(x).

However, compiling the entire model works:

dae = torch.compile(dae, fullgraph=True)

I'm unsure if this is expected behavior or if I'm doing something wrong. Is there a recommended way to compile just the encode method for AutoencoderDC?

I was advised to use the more targeted approach of compiling only the encode method for better performance, but it seems like the DC-AE model might have some internal structure that prevents this optimization pattern.

Any guidance on the correct way to apply torch.compile optimizations to AutoencoderDC would be greatly appreciated. Should I stick with compiling the entire model, or is there a way to make method-level compilation work?
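For what it's worth, one pattern that sometimes sidesteps problems with compiling a rebound bound method is to compile a plain wrapper function that calls the model instead of reassigning `dae.encode`. Below is a minimal sketch of that pattern on a toy module (not the actual AutoencoderDC, which would require downloading weights); `backend="eager"` is used here only so the example runs without a compiler toolchain, and is not part of the suggestion:

```python
import torch
import torch.nn as nn

# Toy stand-in for an encoder; AutoencoderDC itself is not used here.
class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = TinyEncoder().eval()

# Instead of rebinding model.encode/forward, compile a plain function that
# closes over the model. The bound method stays untouched, which avoids the
# name-resolution issues Dynamo can hit when tracing a reassigned method.
@torch.compile(fullgraph=True, backend="eager")
def encode(x):
    return model(x)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    out = encode(x)
print(out.shape)
```

I don't know whether this wrapper pattern is the officially recommended approach for diffusers models, so guidance either way would still be appreciated.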

Reproduction

import torch
from diffusers import AutoencoderDC

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers",
    torch_dtype=torch.bfloat16
).to(device).eval()

# This fails with "name 'torch' is not defined"
dae.encode = torch.compile(dae.encode, fullgraph=True)

# Test
x = torch.randn(1, 3, 512, 512, device=device, dtype=torch.bfloat16)
out = dae.encode(x)  # Error occurs here

# This works fine
dae = torch.compile(dae, fullgraph=True)

Logs

Testing torch.compile(dae.encode, fullgraph=True)
/data1/tzz/anaconda_dir/envs/Sana/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:150: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
  ✗ Error: name 'torch' is not defined

System Info

  • 🤗 Diffusers version: 0.34.0.dev0
  • Platform: Linux-5.15.0-142-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.18
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.0
  • Transformers version: 4.45.2
  • Accelerate version: 1.7.0
  • PEFT version: 0.15.2
  • Bitsandbytes version: 0.46.0
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.27.post2
  • Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
    NVIDIA A100-SXM4-80GB, 81920 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

Labels

bug (Something isn't working)