SDXLCFGCutoffCallback does not work with StableDiffusionXLControlNetPipeline #8686

Closed
@rootonchair

Describe the bug

Running SDXLCFGCutoffCallback with StableDiffusionXLControlNetPipeline raises the following error:

diffusers/src/diffusers/models/attention.py:372, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
    364         norm_hidden_states = self.pos_embed(norm_hidden_states)
    366     attn_output = self.attn2(
    367         norm_hidden_states,
    368         encoder_hidden_states=encoder_hidden_states,
    369         attention_mask=encoder_attention_mask,
    370         **cross_attention_kwargs,
    371     )
--> 372     hidden_states = attn_output + hidden_states
    374 # 4. Feed-forward
    375 # i2vgen doesn't have this norm 🤷‍♂️
    376 if self.norm_type == "ada_norm_continuous":

RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1

This occurs because the conditioning image (https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1488) is not converted back to batch size 1 once the callback turns off classifier-free guidance.

So the fix would be either to add a new callback for ControlNet or to change the current callback so that it also converts the image back to batch size 1.
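To make the second option concrete, here is a rough sketch of what such a callback could look like, written as a plain function with the `callback_on_step_end` signature. It is not the library's implementation: `make_controlnet_cfg_cutoff` and `TENSOR_INPUTS` are made-up names, and it assumes the pipeline would pass `image` to the callback and read it back afterwards, which the ControlNet pipeline may not do at the time of this report. The demo at the bottom uses a stand-in object and plain lists in place of a real pipeline and tensors, since only slicing along the batch dimension matters here.

```python
from types import SimpleNamespace

# Hypothetical list of inputs to trim. "image" is the addition proposed in this issue;
# the other three are the SDXL inputs that get trimmed when CFG is switched off.
TENSOR_INPUTS = ("prompt_embeds", "add_text_embeds", "add_time_ids", "image")


def make_controlnet_cfg_cutoff(cutoff_step_ratio):
    """Build a step-end callback that disables CFG at a given fraction of the steps."""

    def on_step_end(pipeline, step_index, timestep, callback_kwargs):
        cutoff_step = int(pipeline.num_timesteps * cutoff_step_ratio)
        if step_index == cutoff_step:
            for name in TENSOR_INPUTS:
                if name in callback_kwargs:
                    # With CFG on, the batch is [unconditional, conditional];
                    # keep only the last (conditional) entry.
                    callback_kwargs[name] = callback_kwargs[name][-1:]
            # Assumption: a guidance scale of 0.0 makes the pipeline skip CFG.
            pipeline._guidance_scale = 0.0
        return callback_kwargs

    return on_step_end


# Stand-in demo: lists of length 2 play the role of batch-2 tensors.
fake_pipe = SimpleNamespace(num_timesteps=10, _guidance_scale=5.0)
callback = make_controlnet_cfg_cutoff(cutoff_step_ratio=0.4)
kwargs = {"prompt_embeds": ["uncond", "cond"], "image": ["canny", "canny"]}
kwargs = callback(fake_pipe, 4, None, kwargs)
print({name: len(value) for name, value in kwargs.items()}, fake_pipe._guidance_scale)
```

With a real pipeline, the same function would also need `"image"` listed in `callback_on_step_end_tensor_inputs`, and the pipeline would have to accept that name and use the returned value in its denoising loop.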

Reproduction

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
from diffusers.callbacks import SDXLCFGCutoffCallback
from diffusers.utils import load_image, make_image_grid
from PIL import Image
import cv2
import numpy as np
import torch

original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)

image = np.array(original_image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True
)
pipe.enable_model_cpu_offload()

prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = 'low quality, bad quality, sketches'
callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    image=canny_image,
    controlnet_conditioning_scale=0.5,
    callback_on_step_end=callback,
).images[0]
make_image_grid([original_image, canny_image, image], rows=1, cols=3)

Logs

No response

System Info

  • 🤗 Diffusers version: 0.29.0.dev0
  • Platform: Linux-4.18.0-408.el8.x86_64-x86_64-with-glibc2.17
  • Running on a notebook?: No
  • Running on Google Colab?: No
  • Python version: 3.8.13
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.23.4
  • Transformers version: 4.40.0.dev0
  • Accelerate version: 0.28.0
  • PEFT version: 0.11.1
  • Bitsandbytes version: 0.43.1
  • Safetensors version: 0.4.2
  • xFormers version: 0.0.22.post7
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@sayakpaul @yiyixuxu

Labels

bug: Something isn't working
