@@ -210,6 +210,68 @@ def test_stable_diffusion_upscale_batch(self):
         image = output.images
         assert image.shape[0] == 2
 
+    def test_stable_diffusion_upscale_prompt_embeds(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet_upscale
+        low_res_scheduler = DDPMScheduler()
+        scheduler = DDIMScheduler(prediction_type="v_prediction")
+        vae = self.dummy_vae
+        text_encoder = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
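+        # build a 64x64 low-resolution PIL input from the dummy image tensor (NCHW -> HWC)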
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+
+        # assemble all components into the upscale pipeline
+        sd_pipe = StableDiffusionUpscalePipeline(
+            unet=unet,
+            low_res_scheduler=low_res_scheduler,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            max_noise_level=350,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
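+        # first pass: generate from the raw text prompt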
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            image=low_res_image,
+            generator=generator,
+            guidance_scale=6.0,
+            noise_level=20,
+            num_inference_steps=2,
+            output_type="np",
+        )
+
+        image = output.images
+
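+        # second pass: pre-compute the prompt embeddings and pass them in directly;
+        # re-seeding the generator should make both runs produce the same image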
+        generator = torch.Generator(device=device).manual_seed(0)
+        prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
+        image_from_prompt_embeds = sd_pipe(
+            prompt_embeds=prompt_embeds,
+            image=[low_res_image],
+            generator=generator,
+            guidance_scale=6.0,
+            noise_level=20,
+            num_inference_steps=2,
+            output_type="np",
+            return_dict=False,
+        )[0]
+
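+        # compare the same 3x3 corner patch of both outputs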
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1]
+
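+        # the upscale pipeline enlarges the 64x64 input by a factor of 4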
+        expected_height_width = low_res_image.size[0] * 4
+        assert image.shape == (1, expected_height_width, expected_height_width, 3)
+        expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2
+
     @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
     def test_stable_diffusion_upscale_fp16(self):
         """Test that stable diffusion upscale works with fp16"""