@@ -326,10 +326,50 @@ def decode_latents(self, latents):
326326 image = image .cpu ().permute (0 , 2 , 3 , 1 ).float ().numpy ()
327327 return image
328328
329- def check_inputs (self , prompt , image , noise_level , callback_steps ):
330- if not isinstance (prompt , str ) and not isinstance (prompt , list ):
329+ def check_inputs (
330+ self ,
331+ prompt ,
332+ image ,
333+ noise_level ,
334+ callback_steps ,
335+ negative_prompt = None ,
336+ prompt_embeds = None ,
337+ negative_prompt_embeds = None ,
338+ ):
339+ if (callback_steps is None ) or (
340+ callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 )
341+ ):
342+ raise ValueError (
343+ f"`callback_steps` has to be a positive integer but is { callback_steps } of type"
344+ f" { type (callback_steps )} ."
345+ )
346+
347+ if prompt is not None and prompt_embeds is not None :
348+ raise ValueError (
349+ f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
350+ " only forward one of the two."
351+ )
352+ elif prompt is None and prompt_embeds is None :
353+ raise ValueError (
354+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
355+ )
356+ elif prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
331357 raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
332358
359+ if negative_prompt is not None and negative_prompt_embeds is not None :
360+ raise ValueError (
361+ f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`:"
362+ f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
363+ )
364+
365+ if prompt_embeds is not None and negative_prompt_embeds is not None :
366+ if prompt_embeds .shape != negative_prompt_embeds .shape :
367+ raise ValueError (
368+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
369+ f" got: `prompt_embeds` { prompt_embeds .shape } != `negative_prompt_embeds`"
370+ f" { negative_prompt_embeds .shape } ."
371+ )
372+
333373 if (
334374 not isinstance (image , torch .Tensor )
335375 and not isinstance (image , PIL .Image .Image )
@@ -489,13 +529,27 @@ def __call__(
489529 """
490530
491531 # 1. Check inputs
492- self .check_inputs (prompt , image , noise_level , callback_steps )
532+ self .check_inputs (
533+ prompt ,
534+ image ,
535+ noise_level ,
536+ callback_steps ,
537+ negative_prompt ,
538+ prompt_embeds ,
539+ negative_prompt_embeds ,
540+ )
493541
494542 if image is None :
495543 raise ValueError ("`image` input cannot be undefined." )
496544
497545 # 2. Define call parameters
498- batch_size = 1 if isinstance (prompt , str ) else len (prompt )
546+ if prompt is not None and isinstance (prompt , str ):
547+ batch_size = 1
548+ elif prompt is not None and isinstance (prompt , list ):
549+ batch_size = len (prompt )
550+ else :
551+ batch_size = prompt_embeds .shape [0 ]
552+
499553 device = self ._execution_device
500554 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
501555 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
0 commit comments