Skip to content

Commit 5fd3dca

Browse files
yiyixuxu and yiyixuxu authored
fix a bug in StableDiffusionUpscalePipeline when prompt is None (huggingface#4278)
* fix batch_size
* add test
---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
1 parent a2091b7 commit 5fd3dca

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,13 @@ def check_inputs(
424424

425425
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
426426
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
427-
if isinstance(prompt, str):
427+
if prompt is not None and isinstance(prompt, str):
428428
batch_size = 1
429-
else:
429+
elif prompt is not None and isinstance(prompt, list):
430430
batch_size = len(prompt)
431+
else:
432+
batch_size = prompt_embeds.shape[0]
433+
431434
if isinstance(image, list):
432435
image_batch_size = len(image)
433436
else:

tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,68 @@ def test_stable_diffusion_upscale_batch(self):
210210
image = output.images
211211
assert image.shape[0] == 2
212212

213+
def test_stable_diffusion_upscale_prompt_embeds(self):
214+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
215+
unet = self.dummy_cond_unet_upscale
216+
low_res_scheduler = DDPMScheduler()
217+
scheduler = DDIMScheduler(prediction_type="v_prediction")
218+
vae = self.dummy_vae
219+
text_encoder = self.dummy_text_encoder
220+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
221+
222+
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
223+
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
224+
225+
# build the upscale pipeline with the DDIM scheduler configured above
226+
sd_pipe = StableDiffusionUpscalePipeline(
227+
unet=unet,
228+
low_res_scheduler=low_res_scheduler,
229+
scheduler=scheduler,
230+
vae=vae,
231+
text_encoder=text_encoder,
232+
tokenizer=tokenizer,
233+
max_noise_level=350,
234+
)
235+
sd_pipe = sd_pipe.to(device)
236+
sd_pipe.set_progress_bar_config(disable=None)
237+
238+
prompt = "A painting of a squirrel eating a burger"
239+
generator = torch.Generator(device=device).manual_seed(0)
240+
output = sd_pipe(
241+
[prompt],
242+
image=low_res_image,
243+
generator=generator,
244+
guidance_scale=6.0,
245+
noise_level=20,
246+
num_inference_steps=2,
247+
output_type="np",
248+
)
249+
250+
image = output.images
251+
252+
generator = torch.Generator(device=device).manual_seed(0)
253+
prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
254+
image_from_prompt_embeds = sd_pipe(
255+
prompt_embeds=prompt_embeds,
256+
image=[low_res_image],
257+
generator=generator,
258+
guidance_scale=6.0,
259+
noise_level=20,
260+
num_inference_steps=2,
261+
output_type="np",
262+
return_dict=False,
263+
)[0]
264+
265+
image_slice = image[0, -3:, -3:, -1]
266+
image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1]
267+
268+
expected_height_width = low_res_image.size[0] * 4
269+
assert image.shape == (1, expected_height_width, expected_height_width, 3)
270+
expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661])
271+
272+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
273+
assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2
274+
213275
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
214276
def test_stable_diffusion_upscale_fp16(self):
215277
"""Test that stable diffusion upscale works with fp16"""

0 commit comments

Comments
 (0)