Skip to content

Commit 2c82e0c

Browse files
authored
Reorganize pipeline tests (huggingface#963)
* Reorganize pipeline tests * fix vq
1 parent 2d35f67 commit 2c82e0c

27 files changed

+993
-495
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import requests
1515
from packaging import version
1616

17-
from .import_utils import is_flax_available, is_torch_available
17+
from .import_utils import is_flax_available, is_onnx_available, is_torch_available
1818

1919

2020
global_rng = random.Random()
@@ -100,13 +100,34 @@ def slow(test_case):
100100
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
101101

102102

103+
def require_torch(test_case):
104+
"""
105+
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
106+
"""
107+
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
108+
109+
110+
def require_torch_gpu(test_case):
111+
"""Decorator marking a test that requires CUDA and PyTorch."""
112+
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
113+
test_case
114+
)
115+
116+
103117
def require_flax(test_case):
104118
"""
105119
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
106120
"""
107121
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
108122

109123

124+
def require_onnxruntime(test_case):
125+
"""
126+
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
127+
"""
128+
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
129+
130+
110131
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
111132
"""
112133
Args:

tests/pipelines/__init__.py

Whitespace-only changes.

tests/pipelines/ddim/__init__.py

Whitespace-only changes.

tests/pipelines/ddim/test_ddim.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import numpy as np
19+
import torch
20+
21+
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
22+
from diffusers.utils.testing_utils import require_torch, slow, torch_device
23+
24+
from ...test_pipelines_common import PipelineTesterMixin
25+
26+
27+
torch.backends.cuda.matmul.allow_tf32 = False
28+
29+
30+
class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
31+
@property
32+
def dummy_uncond_unet(self):
33+
torch.manual_seed(0)
34+
model = UNet2DModel(
35+
block_out_channels=(32, 64),
36+
layers_per_block=2,
37+
sample_size=32,
38+
in_channels=3,
39+
out_channels=3,
40+
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
41+
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
42+
)
43+
return model
44+
45+
def test_inference(self):
46+
unet = self.dummy_uncond_unet
47+
scheduler = DDIMScheduler()
48+
49+
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
50+
ddpm.to(torch_device)
51+
ddpm.set_progress_bar_config(disable=None)
52+
53+
# Warmup pass when using mps (see #372)
54+
if torch_device == "mps":
55+
_ = ddpm(num_inference_steps=1)
56+
57+
generator = torch.manual_seed(0)
58+
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
59+
60+
generator = torch.manual_seed(0)
61+
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
62+
63+
image_slice = image[0, -3:, -3:, -1]
64+
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
65+
66+
assert image.shape == (1, 32, 32, 3)
67+
expected_slice = np.array(
68+
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
69+
)
70+
tolerance = 1e-2 if torch_device != "mps" else 3e-2
71+
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
72+
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
73+
74+
75+
@slow
76+
@require_torch
77+
class DDIMPipelineIntegrationTests(unittest.TestCase):
78+
def test_inference_ema_bedroom(self):
79+
model_id = "google/ddpm-ema-bedroom-256"
80+
81+
unet = UNet2DModel.from_pretrained(model_id)
82+
scheduler = DDIMScheduler.from_config(model_id)
83+
84+
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
85+
ddpm.to(torch_device)
86+
ddpm.set_progress_bar_config(disable=None)
87+
88+
generator = torch.manual_seed(0)
89+
image = ddpm(generator=generator, output_type="numpy").images
90+
91+
image_slice = image[0, -3:, -3:, -1]
92+
93+
assert image.shape == (1, 256, 256, 3)
94+
expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
95+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
96+
97+
def test_inference_cifar10(self):
98+
model_id = "google/ddpm-cifar10-32"
99+
100+
unet = UNet2DModel.from_pretrained(model_id)
101+
scheduler = DDIMScheduler()
102+
103+
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
104+
ddim.to(torch_device)
105+
ddim.set_progress_bar_config(disable=None)
106+
107+
generator = torch.manual_seed(0)
108+
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
109+
110+
image_slice = image[0, -3:, -3:, -1]
111+
112+
assert image.shape == (1, 32, 32, 3)
113+
expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
114+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

tests/pipelines/ddpm/__init__.py

Whitespace-only changes.

tests/pipelines/ddpm/test_ddpm.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import numpy as np
19+
import torch
20+
21+
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
22+
from diffusers.utils.testing_utils import require_torch, slow, torch_device
23+
24+
from ...test_pipelines_common import PipelineTesterMixin
25+
26+
27+
torch.backends.cuda.matmul.allow_tf32 = False
28+
29+
30+
class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
31+
# FIXME: add fast tests
32+
pass
33+
34+
35+
@slow
36+
@require_torch
37+
class DDPMPipelineIntegrationTests(unittest.TestCase):
38+
def test_inference_cifar10(self):
39+
model_id = "google/ddpm-cifar10-32"
40+
41+
unet = UNet2DModel.from_pretrained(model_id)
42+
scheduler = DDPMScheduler.from_config(model_id)
43+
44+
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
45+
ddpm.to(torch_device)
46+
ddpm.set_progress_bar_config(disable=None)
47+
48+
generator = torch.manual_seed(0)
49+
image = ddpm(generator=generator, output_type="numpy").images
50+
51+
image_slice = image[0, -3:, -3:, -1]
52+
53+
assert image.shape == (1, 32, 32, 3)
54+
expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
55+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

tests/pipelines/karras_ve/__init__.py

Whitespace-only changes.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import numpy as np
19+
import torch
20+
21+
from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel
22+
from diffusers.utils.testing_utils import require_torch, slow, torch_device
23+
24+
from ...test_pipelines_common import PipelineTesterMixin
25+
26+
27+
torch.backends.cuda.matmul.allow_tf32 = False
28+
29+
30+
class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
31+
@property
32+
def dummy_uncond_unet(self):
33+
torch.manual_seed(0)
34+
model = UNet2DModel(
35+
block_out_channels=(32, 64),
36+
layers_per_block=2,
37+
sample_size=32,
38+
in_channels=3,
39+
out_channels=3,
40+
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
41+
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
42+
)
43+
return model
44+
45+
def test_inference(self):
46+
unet = self.dummy_uncond_unet
47+
scheduler = KarrasVeScheduler()
48+
49+
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
50+
pipe.to(torch_device)
51+
pipe.set_progress_bar_config(disable=None)
52+
53+
generator = torch.manual_seed(0)
54+
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
55+
56+
generator = torch.manual_seed(0)
57+
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
58+
59+
image_slice = image[0, -3:, -3:, -1]
60+
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
61+
62+
assert image.shape == (1, 32, 32, 3)
63+
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
64+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
65+
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
66+
67+
68+
@slow
69+
@require_torch
70+
class KarrasVePipelineIntegrationTests(unittest.TestCase):
71+
def test_inference(self):
72+
model_id = "google/ncsnpp-celebahq-256"
73+
model = UNet2DModel.from_pretrained(model_id)
74+
scheduler = KarrasVeScheduler()
75+
76+
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
77+
pipe.to(torch_device)
78+
pipe.set_progress_bar_config(disable=None)
79+
80+
generator = torch.manual_seed(0)
81+
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
82+
83+
image_slice = image[0, -3:, -3:, -1]
84+
assert image.shape == (1, 256, 256, 3)
85+
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
86+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

tests/pipelines/latent_diffusion/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)