Skip to content

Commit 5a287d3

Browse files
[SDXL] Make sure multi batch prompt embeds works (huggingface#5073)
* [SDXL] Make sure multi batch prompt embeds works * [SDXL] Make sure multi batch prompt embeds works * improve more * improve more * Apply suggestions from code review
1 parent 65c162a commit 5a287d3

File tree

9 files changed

+149
-35
lines changed

9 files changed

+149
-35
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,9 @@ def encode_prompt(
314314
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
315315
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
316316

317-
if prompt is not None and isinstance(prompt, str):
318-
batch_size = 1
319-
elif prompt is not None and isinstance(prompt, list):
317+
prompt = [prompt] if isinstance(prompt, str) else prompt
318+
319+
if prompt is not None:
320320
batch_size = len(prompt)
321321
else:
322322
batch_size = prompt_embeds.shape[0]
@@ -329,6 +329,8 @@ def encode_prompt(
329329

330330
if prompt_embeds is None:
331331
prompt_2 = prompt_2 or prompt
332+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
333+
332334
# textual inversion: process multi-vector tokens if necessary
333335
prompt_embeds_list = []
334336
prompts = [prompt, prompt_2]
@@ -378,14 +380,18 @@ def encode_prompt(
378380
negative_prompt = negative_prompt or ""
379381
negative_prompt_2 = negative_prompt_2 or negative_prompt
380382

383+
# normalize str to list
384+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
385+
negative_prompt_2 = (
386+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
387+
)
388+
381389
uncond_tokens: List[str]
382390
if prompt is not None and type(prompt) is not type(negative_prompt):
383391
raise TypeError(
384392
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
385393
f" {type(prompt)}."
386394
)
387-
elif isinstance(negative_prompt, str):
388-
uncond_tokens = [negative_prompt, negative_prompt_2]
389395
elif batch_size != len(negative_prompt):
390396
raise ValueError(
391397
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,9 @@ def encode_prompt(
287287
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
288288
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
289289

290-
if prompt is not None and isinstance(prompt, str):
291-
batch_size = 1
292-
elif prompt is not None and isinstance(prompt, list):
290+
prompt = [prompt] if isinstance(prompt, str) else prompt
291+
292+
if prompt is not None:
293293
batch_size = len(prompt)
294294
else:
295295
batch_size = prompt_embeds.shape[0]
@@ -302,6 +302,8 @@ def encode_prompt(
302302

303303
if prompt_embeds is None:
304304
prompt_2 = prompt_2 or prompt
305+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
306+
305307
# textual inversion: process multi-vector tokens if necessary
306308
prompt_embeds_list = []
307309
prompts = [prompt, prompt_2]
@@ -351,14 +353,18 @@ def encode_prompt(
351353
negative_prompt = negative_prompt or ""
352354
negative_prompt_2 = negative_prompt_2 or negative_prompt
353355

356+
# normalize str to list
357+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
358+
negative_prompt_2 = (
359+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
360+
)
361+
354362
uncond_tokens: List[str]
355363
if prompt is not None and type(prompt) is not type(negative_prompt):
356364
raise TypeError(
357365
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
358366
f" {type(prompt)}."
359367
)
360-
elif isinstance(negative_prompt, str):
361-
uncond_tokens = [negative_prompt, negative_prompt_2]
362368
elif batch_size != len(negative_prompt):
363369
raise ValueError(
364370
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,9 @@ def encode_prompt(
325325
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
326326
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
327327

328-
if prompt is not None and isinstance(prompt, str):
329-
batch_size = 1
330-
elif prompt is not None and isinstance(prompt, list):
328+
prompt = [prompt] if isinstance(prompt, str) else prompt
329+
330+
if prompt is not None:
331331
batch_size = len(prompt)
332332
else:
333333
batch_size = prompt_embeds.shape[0]
@@ -340,6 +340,8 @@ def encode_prompt(
340340

341341
if prompt_embeds is None:
342342
prompt_2 = prompt_2 or prompt
343+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
344+
343345
# textual inversion: process multi-vector tokens if necessary
344346
prompt_embeds_list = []
345347
prompts = [prompt, prompt_2]
@@ -389,14 +391,18 @@ def encode_prompt(
389391
negative_prompt = negative_prompt or ""
390392
negative_prompt_2 = negative_prompt_2 or negative_prompt
391393

394+
# normalize str to list
395+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
396+
negative_prompt_2 = (
397+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
398+
)
399+
392400
uncond_tokens: List[str]
393401
if prompt is not None and type(prompt) is not type(negative_prompt):
394402
raise TypeError(
395403
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
396404
f" {type(prompt)}."
397405
)
398-
elif isinstance(negative_prompt, str):
399-
uncond_tokens = [negative_prompt, negative_prompt_2]
400406
elif batch_size != len(negative_prompt):
401407
raise ValueError(
402408
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,9 @@ def encode_prompt(
263263
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
264264
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
265265

266-
if prompt is not None and isinstance(prompt, str):
267-
batch_size = 1
268-
elif prompt is not None and isinstance(prompt, list):
266+
prompt = [prompt] if isinstance(prompt, str) else prompt
267+
268+
if prompt is not None:
269269
batch_size = len(prompt)
270270
else:
271271
batch_size = prompt_embeds.shape[0]
@@ -278,6 +278,8 @@ def encode_prompt(
278278

279279
if prompt_embeds is None:
280280
prompt_2 = prompt_2 or prompt
281+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
282+
281283
# textual inversion: process multi-vector tokens if necessary
282284
prompt_embeds_list = []
283285
prompts = [prompt, prompt_2]
@@ -327,14 +329,18 @@ def encode_prompt(
327329
negative_prompt = negative_prompt or ""
328330
negative_prompt_2 = negative_prompt_2 or negative_prompt
329331

332+
# normalize str to list
333+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
334+
negative_prompt_2 = (
335+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
336+
)
337+
330338
uncond_tokens: List[str]
331339
if prompt is not None and type(prompt) is not type(negative_prompt):
332340
raise TypeError(
333341
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
334342
f" {type(prompt)}."
335343
)
336-
elif isinstance(negative_prompt, str):
337-
uncond_tokens = [negative_prompt, negative_prompt_2]
338344
elif batch_size != len(negative_prompt):
339345
raise ValueError(
340346
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ def encode_prompt(
270270
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
271271
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
272272

273-
if prompt is not None and isinstance(prompt, str):
274-
batch_size = 1
275-
elif prompt is not None and isinstance(prompt, list):
273+
prompt = [prompt] if isinstance(prompt, str) else prompt
274+
275+
if prompt is not None:
276276
batch_size = len(prompt)
277277
else:
278278
batch_size = prompt_embeds.shape[0]
@@ -285,6 +285,8 @@ def encode_prompt(
285285

286286
if prompt_embeds is None:
287287
prompt_2 = prompt_2 or prompt
288+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
289+
288290
# textual inversion: process multi-vector tokens if necessary
289291
prompt_embeds_list = []
290292
prompts = [prompt, prompt_2]
@@ -334,14 +336,18 @@ def encode_prompt(
334336
negative_prompt = negative_prompt or ""
335337
negative_prompt_2 = negative_prompt_2 or negative_prompt
336338

339+
# normalize str to list
340+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
341+
negative_prompt_2 = (
342+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
343+
)
344+
337345
uncond_tokens: List[str]
338346
if prompt is not None and type(prompt) is not type(negative_prompt):
339347
raise TypeError(
340348
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
341349
f" {type(prompt)}."
342350
)
343-
elif isinstance(negative_prompt, str):
344-
uncond_tokens = [negative_prompt, negative_prompt_2]
345351
elif batch_size != len(negative_prompt):
346352
raise ValueError(
347353
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,9 @@ def encode_prompt(
419419
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
420420
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
421421

422-
if prompt is not None and isinstance(prompt, str):
423-
batch_size = 1
424-
elif prompt is not None and isinstance(prompt, list):
422+
prompt = [prompt] if isinstance(prompt, str) else prompt
423+
424+
if prompt is not None:
425425
batch_size = len(prompt)
426426
else:
427427
batch_size = prompt_embeds.shape[0]
@@ -434,6 +434,8 @@ def encode_prompt(
434434

435435
if prompt_embeds is None:
436436
prompt_2 = prompt_2 or prompt
437+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
438+
437439
# textual inversion: process multi-vector tokens if necessary
438440
prompt_embeds_list = []
439441
prompts = [prompt, prompt_2]
@@ -483,14 +485,18 @@ def encode_prompt(
483485
negative_prompt = negative_prompt or ""
484486
negative_prompt_2 = negative_prompt_2 or negative_prompt
485487

488+
# normalize str to list
489+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
490+
negative_prompt_2 = (
491+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
492+
)
493+
486494
uncond_tokens: List[str]
487495
if prompt is not None and type(prompt) is not type(negative_prompt):
488496
raise TypeError(
489497
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
490498
f" {type(prompt)}."
491499
)
492-
elif isinstance(negative_prompt, str):
493-
uncond_tokens = [negative_prompt, negative_prompt_2]
494500
elif batch_size != len(negative_prompt):
495501
raise ValueError(
496502
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,9 @@ def encode_prompt(
287287
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
288288
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
289289

290-
if prompt is not None and isinstance(prompt, str):
291-
batch_size = 1
292-
elif prompt is not None and isinstance(prompt, list):
290+
prompt = [prompt] if isinstance(prompt, str) else prompt
291+
292+
if prompt is not None:
293293
batch_size = len(prompt)
294294
else:
295295
batch_size = prompt_embeds.shape[0]
@@ -302,6 +302,8 @@ def encode_prompt(
302302

303303
if prompt_embeds is None:
304304
prompt_2 = prompt_2 or prompt
305+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
306+
305307
# textual inversion: process multi-vector tokens if necessary
306308
prompt_embeds_list = []
307309
prompts = [prompt, prompt_2]
@@ -351,14 +353,18 @@ def encode_prompt(
351353
negative_prompt = negative_prompt or ""
352354
negative_prompt_2 = negative_prompt_2 or negative_prompt
353355

356+
# normalize str to list
357+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
358+
negative_prompt_2 = (
359+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
360+
)
361+
354362
uncond_tokens: List[str]
355363
if prompt is not None and type(prompt) is not type(negative_prompt):
356364
raise TypeError(
357365
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
358366
f" {type(prompt)}."
359367
)
360-
elif isinstance(negative_prompt, str):
361-
uncond_tokens = [negative_prompt, negative_prompt_2]
362368
elif batch_size != len(negative_prompt):
363369
raise ValueError(
364370
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,42 @@ def test_stable_diffusion_xl_offloads(self):
261261
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
262262
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
263263

264+
def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
265+
components = self.get_dummy_components()
266+
sd_pipe = StableDiffusionXLPipeline(**components)
267+
sd_pipe = sd_pipe.to(torch_device)
268+
sd_pipe.set_progress_bar_config(disable=None)
269+
270+
# forward without prompt embeds
271+
generator_device = "cpu"
272+
inputs = self.get_dummy_inputs(generator_device)
273+
inputs["prompt"] = 3 * [inputs["prompt"]]
274+
275+
output = sd_pipe(**inputs)
276+
image_slice_1 = output.images[0, -3:, -3:, -1]
277+
278+
# forward with prompt embeds
279+
generator_device = "cpu"
280+
inputs = self.get_dummy_inputs(generator_device)
281+
prompt = 3 * [inputs.pop("prompt")]
282+
283+
(
284+
prompt_embeds,
285+
_,
286+
pooled_prompt_embeds,
287+
_,
288+
) = sd_pipe.encode_prompt(prompt)
289+
290+
output = sd_pipe(
291+
**inputs,
292+
prompt_embeds=prompt_embeds,
293+
pooled_prompt_embeds=pooled_prompt_embeds,
294+
)
295+
image_slice_2 = output.images[0, -3:, -3:, -1]
296+
297+
# make sure that it's equal
298+
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
299+
264300
def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
265301
components = self.get_dummy_components()
266302
pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,42 @@ def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
559559
# make sure that it's equal
560560
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
561561

562+
def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
563+
components = self.get_dummy_components()
564+
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
565+
sd_pipe = sd_pipe.to(torch_device)
566+
sd_pipe.set_progress_bar_config(disable=None)
567+
568+
# forward without prompt embeds
569+
generator_device = "cpu"
570+
inputs = self.get_dummy_inputs(generator_device)
571+
inputs["prompt"] = 3 * [inputs["prompt"]]
572+
573+
output = sd_pipe(**inputs)
574+
image_slice_1 = output.images[0, -3:, -3:, -1]
575+
576+
# forward with prompt embeds
577+
generator_device = "cpu"
578+
inputs = self.get_dummy_inputs(generator_device)
579+
prompt = 3 * [inputs.pop("prompt")]
580+
581+
(
582+
prompt_embeds,
583+
_,
584+
pooled_prompt_embeds,
585+
_,
586+
) = sd_pipe.encode_prompt(prompt)
587+
588+
output = sd_pipe(
589+
**inputs,
590+
prompt_embeds=prompt_embeds,
591+
pooled_prompt_embeds=pooled_prompt_embeds,
592+
)
593+
image_slice_2 = output.images[0, -3:, -3:, -1]
594+
595+
# make sure that it's equal
596+
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
597+
562598
def test_attention_slicing_forward_pass(self):
563599
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
564600

0 commit comments

Comments
 (0)