@@ -843,34 +843,36 @@ def test_ldm_text2img_fast(self):
843843
844844 @slow
845845 def test_stable_diffusion (self ):
846- ldm = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1-diffusers" )
846+ pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1-diffusers" )
847847
848848 prompt = "A painting of a squirrel eating a burger"
849849 generator = torch .manual_seed (0 )
850- image = ldm ([prompt ], generator = generator , guidance_scale = 6.0 , num_inference_steps = 20 , output_type = "numpy" )[
850+ image = pipe ([prompt ], generator = generator , guidance_scale = 6.0 , num_inference_steps = 20 , output_type = "numpy" )[
851851 "sample"
852852 ]
853853
854854 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
855855
856- # TODO: update the expected_slice
857856 assert image .shape == (1 , 512 , 512 , 3 )
858- expected_slice = np .array ([0.9256 , 0.9340 , 0.8933 , 0.9361 , 0.9113 , 0.8727 , 0.9122 , 0.8745 , 0.8099 ])
857+ # fmt: off
858+ expected_slice = np .array ([0.09609553 , 0.09020892 , 0.07902172 , 0.07634321 , 0.08755809 , 0.06491277 , 0.07687345 , 0.07173461 , 0.07374045 ])
859+ # fmt: on
859860 assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
860861
861862 @slow
862863 def test_stable_diffusion_fast (self ):
863- ldm = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1-diffusers" )
864+ pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1-diffusers" )
864865
865866 prompt = "A painting of a squirrel eating a burger"
866867 generator = torch .manual_seed (0 )
867- image = ldm ([prompt ], generator = generator , num_inference_steps = 1 , output_type = "numpy" )["sample" ]
868+ image = pipe ([prompt ], generator = generator , num_inference_steps = 5 , output_type = "numpy" )["sample" ]
868869
869870 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
870871
871- # TODO: update the expected_slice
872872 assert image .shape == (1 , 512 , 512 , 3 )
873- expected_slice = np .array ([0.3163 , 0.8670 , 0.6465 , 0.1865 , 0.6291 , 0.5139 , 0.2824 , 0.3723 , 0.4344 ])
873+ # fmt: off
874+ expected_slice = np .array ([0.16537648 , 0.17572534 , 0.14657784 , 0.20084214 , 0.19819549 , 0.16032678 , 0.30438453 , 0.22730353 , 0.21307352 ])
875+ # fmt: on
874876 assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
875877
876878 @slow
0 commit comments