Skip to content

Commit c25d8c9

Browse files
authored
add tests for stable diffusion pipeline (huggingface#178)
add tests for sd pipeline
1 parent 5782e03 commit c25d8c9

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

tests/test_modeling_utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -843,34 +843,36 @@ def test_ldm_text2img_fast(self):
843843

844844
@slow
845845
def test_stable_diffusion(self):
846-
ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
846+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
847847

848848
prompt = "A painting of a squirrel eating a burger"
849849
generator = torch.manual_seed(0)
850-
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
850+
image = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
851851
"sample"
852852
]
853853

854854
image_slice = image[0, -3:, -3:, -1]
855855

856-
# TODO: update the expected_slice
857856
assert image.shape == (1, 512, 512, 3)
858-
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
857+
# fmt: off
858+
expected_slice = np.array([0.09609553, 0.09020892, 0.07902172, 0.07634321, 0.08755809, 0.06491277, 0.07687345, 0.07173461, 0.07374045])
859+
# fmt: on
859860
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
860861

861862
@slow
862863
def test_stable_diffusion_fast(self):
863-
ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
864+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
864865

865866
prompt = "A painting of a squirrel eating a burger"
866867
generator = torch.manual_seed(0)
867-
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
868+
image = pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
868869

869870
image_slice = image[0, -3:, -3:, -1]
870871

871-
# TODO: update the expected_slice
872872
assert image.shape == (1, 512, 512, 3)
873-
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
873+
# fmt: off
874+
expected_slice = np.array([0.16537648, 0.17572534, 0.14657784, 0.20084214, 0.19819549, 0.16032678, 0.30438453, 0.22730353, 0.21307352])
875+
# fmt: on
874876
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
875877

876878
@slow

0 commit comments

Comments
 (0)