Description
Describe the bug
Running SDXLCFGCutoffCallback with the ControlNet SDXL pipeline raises the following error:
diffusers/src/diffusers/models/attention.py:372, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
364 norm_hidden_states = self.pos_embed(norm_hidden_states)
366 attn_output = self.attn2(
367 norm_hidden_states,
368 encoder_hidden_states=encoder_hidden_states,
369 attention_mask=encoder_attention_mask,
370 **cross_attention_kwargs,
371 )
--> 372 hidden_states = attn_output + hidden_states
374 # 4. Feed-forward
375 # i2vgen doesn't have this norm 🤷♂️
376 if self.norm_type == "ada_norm_continuous":
RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1
This occurs because the conditioning image
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1488) is not converted back to batch size 1 when the callback trims the other inputs.
A fix would be either to add a new callback for ControlNet or to change the current callback so that it also converts the image
back to batch size 1.
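To make the proposed fix concrete, here is a minimal sketch of a step-end callback that trims the conditioning image together with the text inputs. It is not the diffusers implementation: `keep_conditional_half` and `make_cutoff_callback` are hypothetical names, and the sketch assumes the pipeline passes the conditioning image to the callback under the key "image", which may not be the case in the version reported here. The demo at the bottom uses plain lists and a stand-in object instead of real tensors and a real pipeline, so it runs without diffusers installed.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Sequence


def keep_conditional_half(callback_kwargs: Dict[str, Any], keys: Sequence[str]) -> Dict[str, Any]:
    """Hypothetical helper: keep only the last batch entry of each listed input.

    With classifier-free guidance the inputs are stacked as
    [unconditional, conditional], so [-1:] keeps the conditional part
    and leaves a batch of size 1.
    """
    for key in keys:
        if key in callback_kwargs:
            callback_kwargs[key] = callback_kwargs[key][-1:]
    return callback_kwargs


def make_cutoff_callback(cutoff_step: int, keys: Sequence[str]) -> Callable:
    """Build a callback with the (pipeline, step_index, timestep, callback_kwargs) signature."""

    def callback(pipeline, step_index, timestep, callback_kwargs):
        if step_index == cutoff_step:
            keep_conditional_half(callback_kwargs, keys)
            # Turn off guidance from this step on, as the built-in cutoff callbacks do.
            pipeline._guidance_scale = 0.0
        return callback_kwargs

    return callback


# Demo with stand-ins: lists instead of tensors, SimpleNamespace instead of a pipeline.
demo_pipe = SimpleNamespace(_guidance_scale=5.0)
demo_kwargs = {
    "prompt_embeds": ["negative", "positive"],
    "image": ["canny", "canny"],  # assumed key for the conditioning image
}
demo_callback = make_cutoff_callback(cutoff_step=2, keys=["prompt_embeds", "image"])
demo_before = demo_callback(demo_pipe, 1, None, dict(demo_kwargs))
demo_after = demo_callback(demo_pipe, 2, None, dict(demo_kwargs))
```

With a real pipeline, the keys would also need to include "add_text_embeds" and "add_time_ids", and "image" would have to be among the tensor inputs the pipeline allows callbacks to modify.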
Reproduction
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
from diffusers.callbacks import SDXLCFGCutoffCallback
from diffusers.utils import load_image, make_image_grid
from PIL import Image
import cv2
import numpy as np
import torch
original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
image = np.array(original_image)
low_threshold = 100
high_threshold = 200
image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True
)
pipe.enable_model_cpu_offload()
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = 'low quality, bad quality, sketches'
callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    image=canny_image,
    controlnet_conditioning_scale=0.5,
    callback_on_step_end=callback,
).images[0]
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
Logs
No response
System Info
- 🤗 Diffusers version: 0.29.0.dev0
- Platform: Linux-4.18.0-408.el8.x86_64-x86_64-with-glibc2.17
- Running on a notebook?: No
- Running on Google Colab?: No
- Python version: 3.8.13
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.23.4
- Transformers version: 4.40.0.dev0
- Accelerate version: 0.28.0
- PEFT version: 0.11.1
- Bitsandbytes version: 0.43.1
- Safetensors version: 0.4.2
- xFormers version: 0.0.22.post7
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No