@@ -325,8 +325,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
325325 def __call__ (
326326 self ,
327327 prompt : Union [str , List [str ]],
328- height : int = 512 ,
329- width : int = 512 ,
328+ height : Optional [ int ] = None ,
329+ width : Optional [ int ] = None ,
330330 num_inference_steps : int = 50 ,
331331 guidance_scale : float = 7.5 ,
332332 negative_prompt : Optional [Union [str , List [str ]]] = None ,
@@ -345,9 +345,9 @@ def __call__(
345345 Args:
346346 prompt (`str` or `List[str]`):
347347 The prompt or prompts to guide the image generation.
348- height (`int`, *optional*, defaults to 512 ):
348+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor ):
349349 The height in pixels of the generated image.
350- width (`int`, *optional*, defaults to 512 ):
350+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor ):
351351 The width in pixels of the generated image.
352352 num_inference_steps (`int`, *optional*, defaults to 50):
353353 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -393,6 +393,9 @@ def __call__(
393393 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
394394 (nsfw) content, according to the `safety_checker`.
395395 """
396+ # 0. Default height and width to the UNet sample size multiplied by the VAE scale factor
397+ height = height or self .unet .config .sample_size * self .vae_scale_factor
398+ width = width or self .unet .config .sample_size * self .vae_scale_factor
396399
397400 # 1. Check inputs. Raise error if not correct
398401 self .check_inputs (prompt , height , width , callback_steps )
0 commit comments