Skip to content

Commit f3fbf9b

Browse files
authored
Fix check_inputs in upscaler pipeline to allow embeds (huggingface#2892)
* Remove suggestion to use cuDNN benchmark in docs
* removing the wrong line
* add support for embeds
* fix line length
1 parent e1144ac commit f3fbf9b

File tree

1 file changed

+58
-4
lines changed

1 file changed

+58
-4
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 58 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -326,10 +326,50 @@ def decode_latents(self, latents):
326326
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
327327
return image
328328

329-
def check_inputs(self, prompt, image, noise_level, callback_steps):
330-
if not isinstance(prompt, str) and not isinstance(prompt, list):
329+
def check_inputs(
330+
self,
331+
prompt,
332+
image,
333+
noise_level,
334+
callback_steps,
335+
negative_prompt=None,
336+
prompt_embeds=None,
337+
negative_prompt_embeds=None,
338+
):
339+
if (callback_steps is None) or (
340+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
341+
):
342+
raise ValueError(
343+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
344+
f" {type(callback_steps)}."
345+
)
346+
347+
if prompt is not None and prompt_embeds is not None:
348+
raise ValueError(
349+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
350+
" only forward one of the two."
351+
)
352+
elif prompt is None and prompt_embeds is None:
353+
raise ValueError(
354+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
355+
)
356+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
331357
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
332358

359+
if negative_prompt is not None and negative_prompt_embeds is not None:
360+
raise ValueError(
361+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
362+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
363+
)
364+
365+
if prompt_embeds is not None and negative_prompt_embeds is not None:
366+
if prompt_embeds.shape != negative_prompt_embeds.shape:
367+
raise ValueError(
368+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
369+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
370+
f" {negative_prompt_embeds.shape}."
371+
)
372+
333373
if (
334374
not isinstance(image, torch.Tensor)
335375
and not isinstance(image, PIL.Image.Image)
@@ -489,13 +529,27 @@ def __call__(
489529
"""
490530

491531
# 1. Check inputs
492-
self.check_inputs(prompt, image, noise_level, callback_steps)
532+
self.check_inputs(
533+
prompt,
534+
image,
535+
noise_level,
536+
callback_steps,
537+
negative_prompt,
538+
prompt_embeds,
539+
negative_prompt_embeds,
540+
)
493541

494542
if image is None:
495543
raise ValueError("`image` input cannot be undefined.")
496544

497545
# 2. Define call parameters
498-
batch_size = 1 if isinstance(prompt, str) else len(prompt)
546+
if prompt is not None and isinstance(prompt, str):
547+
batch_size = 1
548+
elif prompt is not None and isinstance(prompt, list):
549+
batch_size = len(prompt)
550+
else:
551+
batch_size = prompt_embeds.shape[0]
552+
499553
device = self._execution_device
500554
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
501555
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

0 commit comments

Comments
 (0)