Skip to content

Commit 5893305

Browse files
VersatileDiffusion: fix input processing (huggingface#1568)
* fix versatile diffusion input
* merge main
* `make fix-copies`

Co-authored-by: anton- <[email protected]>
1 parent 31444f5 commit 5893305

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,8 @@ def check_inputs(self, image, height, width, callback_steps):
271271
and not isinstance(image, list)
272272
):
273273
raise ValueError(
274-
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
274+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
275+
f" {type(image)}"
275276
)
276277

277278
if height % 8 != 0 or width % 8 != 0:

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -240,7 +240,8 @@ def check_inputs(self, image, height, width, callback_steps):
240240
and not isinstance(image, list)
241241
):
242242
raise ValueError(
243-
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
243+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
244+
f" {type(image)}"
244245
)
245246

246247
if height % 8 != 0 or width % 8 != 0:

src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,9 @@ def normalize_embeddings(encoder_output):
134134
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
135135
return embeds
136136

137+
if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
138+
prompt = [p for p in prompt]
139+
137140
batch_size = len(prompt) if isinstance(prompt, list) else 1
138141

139142
# get prompt text embeddings
@@ -212,9 +215,17 @@ def prepare_extra_step_kwargs(self, generator, eta):
212215
extra_step_kwargs["generator"] = generator
213216
return extra_step_kwargs
214217

218+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
215219
def check_inputs(self, image, height, width, callback_steps):
216-
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
217-
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
220+
if (
221+
not isinstance(image, torch.Tensor)
222+
and not isinstance(image, PIL.Image.Image)
223+
and not isinstance(image, list)
224+
):
225+
raise ValueError(
226+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
227+
f" {type(image)}"
228+
)
218229

219230
if height % 8 != 0 or width % 8 != 0:
220231
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

0 commit comments

Comments
 (0)