Skip to content

Commit ff65c2d

Browse files
authored
Don't assume 512x512 in k-diffusion pipeline (huggingface#1625)
Don't assume 512x512 in k-diffusion pipeline.
1 parent f1b726e commit ff65c2d

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
325325
def __call__(
326326
self,
327327
prompt: Union[str, List[str]],
328-
height: int = 512,
329-
width: int = 512,
328+
height: Optional[int] = None,
329+
width: Optional[int] = None,
330330
num_inference_steps: int = 50,
331331
guidance_scale: float = 7.5,
332332
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -345,9 +345,9 @@ def __call__(
345345
Args:
346346
prompt (`str` or `List[str]`):
347347
The prompt or prompts to guide the image generation.
348-
height (`int`, *optional*, defaults to 512):
348+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
349349
The height in pixels of the generated image.
350-
width (`int`, *optional*, defaults to 512):
350+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
351351
The width in pixels of the generated image.
352352
num_inference_steps (`int`, *optional*, defaults to 50):
353353
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -393,6 +393,9 @@ def __call__(
393393
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
394394
(nsfw) content, according to the `safety_checker`.
395395
"""
396+
# 0. Default height and width to unet
397+
height = height or self.unet.config.sample_size * self.vae_scale_factor
398+
width = width or self.unet.config.sample_size * self.vae_scale_factor
396399

397400
# 1. Check inputs. Raise error if not correct
398401
self.check_inputs(prompt, height, width, callback_steps)

0 commit comments

Comments
 (0)