Skip to content

Commit d43972a

Browse files
authored
Fixes prompt input checks in StableDiffusion img2img pipeline (huggingface#2206)
* Fixes prompt input checks in img2img Allows providing prompt_embeds instead of the prompt, which is not currently possible as the first check fails. This becomes the same as the function found in https://github.com/huggingface/diffusers/blob/8267c7844504b55366525169187767ef92d1f499/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L393 * Continues the fix This also needs to be fixed. Becomes consistent with https://github.com/huggingface/diffusers/blob/8267c7844504b55366525169187767ef92d1f499/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L558 I've now tested this implementation, and it produces the expected results.
1 parent ffed242 commit d43972a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
428428
def check_inputs(
429429
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
430430
):
431-
if not isinstance(prompt, str) and not isinstance(prompt, list):
432-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
433-
434431
if strength < 0 or strength > 1:
435432
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
436433

@@ -623,7 +620,12 @@ def __call__(
623620
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
624621

625622
# 2. Define call parameters
626-
batch_size = 1 if isinstance(prompt, str) else len(prompt)
623+
if prompt is not None and isinstance(prompt, str):
624+
batch_size = 1
625+
elif prompt is not None and isinstance(prompt, list):
626+
batch_size = len(prompt)
627+
else:
628+
batch_size = prompt_embeds.shape[0]
627629
device = self._execution_device
628630
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
629631
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

0 commit comments

Comments
 (0)