Skip to content

Commit 3ba36f9

Browse files
[SD-XL] Fix sdxl controlnet inference (huggingface#4238)
* Fix controlnet xl inference
* correct some sd xl control inference
1 parent b288684 commit 3ba36f9

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -180,11 +180,19 @@ def enable_model_cpu_offload(self, gpu_id=0):
180180

181181
device = torch.device(f"cuda:{gpu_id}")
182182

183+
if self.device.type != "cpu":
184+
self.to("cpu", silence_dtype_warnings=True)
185+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
186+
187+
model_sequence = (
188+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
189+
)
190+
model_sequence.extend([self.unet, self.vae])
191+
183192
hook = None
184-
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
193+
for cpu_offloaded_model in model_sequence:
185194
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
186195

187-
# control net hook has be manually offloaded as it alternates with unet
188196
cpu_offload_with_hook(self.controlnet, device)
189197

190198
# We'll offload the last model manually.
@@ -639,7 +647,7 @@ def __call__(
639647
height: Optional[int] = None,
640648
width: Optional[int] = None,
641649
num_inference_steps: int = 50,
642-
guidance_scale: float = 7.5,
650+
guidance_scale: float = 5.0,
643651
negative_prompt: Optional[Union[str, List[str]]] = None,
644652
negative_prompt_2: Optional[Union[str, List[str]]] = None,
645653
num_images_per_prompt: Optional[int] = 1,
@@ -657,9 +665,9 @@ def __call__(
657665
guess_mode: bool = False,
658666
control_guidance_start: Union[float, List[float]] = 0.0,
659667
control_guidance_end: Union[float, List[float]] = 1.0,
660-
original_size: Tuple[int, int] = (1024, 1024),
668+
original_size: Tuple[int, int] = None,
661669
crops_coords_top_left: Tuple[int, int] = (0, 0),
662-
target_size: Tuple[int, int] = (1024, 1024),
670+
target_size: Tuple[int, int] = None,
663671
):
664672
r"""
665673
Function invoked when calling the pipeline for generation.
@@ -875,6 +883,9 @@ def __call__(
875883
]
876884
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
877885

886+
original_size = original_size or image.shape[-2:]
887+
target_size = target_size or (height, width)
888+
878889
# 7.2 Prepare added time ids & embeddings
879890
add_text_embeds = pooled_prompt_embeds
880891
add_time_ids = self._get_add_time_ids(

tests/pipelines/controlnet/test_controlnet_sdxl.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
)
2929
from diffusers.utils import randn_tensor, torch_device
3030
from diffusers.utils.import_utils import is_xformers_available
31-
from diffusers.utils.testing_utils import (
32-
enable_full_determinism,
33-
)
31+
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
3432

3533
from ..pipeline_params import (
3634
IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -125,10 +123,10 @@ def get_dummy_components(self):
125123
projection_dim=32,
126124
)
127125
text_encoder = CLIPTextModel(text_encoder_config)
128-
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
126+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
129127

130128
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
131-
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
129+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
132130

133131
components = {
134132
"unet": unet,
@@ -179,6 +177,35 @@ def test_xformers_attention_forwardGenerator_pass(self):
179177
def test_inference_batch_single_identical(self):
180178
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
181179

180+
@require_torch_gpu
181+
def test_stable_diffusion_xl_offloads(self):
182+
pipes = []
183+
components = self.get_dummy_components()
184+
sd_pipe = self.pipeline_class(**components).to(torch_device)
185+
pipes.append(sd_pipe)
186+
187+
components = self.get_dummy_components()
188+
sd_pipe = self.pipeline_class(**components)
189+
sd_pipe.enable_model_cpu_offload()
190+
pipes.append(sd_pipe)
191+
192+
components = self.get_dummy_components()
193+
sd_pipe = self.pipeline_class(**components)
194+
sd_pipe.enable_sequential_cpu_offload()
195+
pipes.append(sd_pipe)
196+
197+
image_slices = []
198+
for pipe in pipes:
199+
pipe.unet.set_default_attn_processor()
200+
201+
inputs = self.get_dummy_inputs(torch_device)
202+
image = pipe(**inputs).images
203+
204+
image_slices.append(image[0, -3:, -3:, -1].flatten())
205+
206+
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
207+
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
208+
182209
def test_stable_diffusion_xl_multi_prompts(self):
183210
components = self.get_dummy_components()
184211
sd_pipe = self.pipeline_class(**components).to(torch_device)

0 commit comments

Comments
 (0)