Skip to content

Commit 68f61a0

Browse files
Make sure torch compile doesn't access unet config (huggingface#4008)
1 parent 4a3e574 commit 68f61a0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def __init__(
129129
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
130130
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
131131
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
132+
self.default_sample_size = self.unet.config.sample_size
132133

133134
self.watermark = StableDiffusionXLWatermarker()
134135

@@ -652,8 +653,8 @@ def __call__(
652653
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
653654
"""
654655
# 0. Default height and width to unet
655-
height = height or self.unet.config.sample_size * self.vae_scale_factor
656-
width = width or self.unet.config.sample_size * self.vae_scale_factor
656+
height = height or self.default_sample_size * self.vae_scale_factor
657+
width = width or self.default_sample_size * self.vae_scale_factor
657658

658659
original_size = original_size or (height, width)
659660
target_size = target_size or (height, width)

0 commit comments

Comments
 (0)