Skip to content

Commit 9e17983

Browse files
authored
Test ResnetBlock2D (huggingface#1850)
* test resnet block * fix code format required by isort * add torch device * nit
1 parent cb8a3db commit 9e17983

File tree

1 file changed

+93
-1
lines changed

1 file changed

+93
-1
lines changed

tests/test_layers_utils.py

File mode changed: 100755 → 100644
Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
2424
from diffusers.models.embeddings import get_timestep_embedding
25-
from diffusers.models.resnet import Downsample2D, Upsample2D
25+
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
2626
from diffusers.models.transformer_2d import Transformer2DModel
2727
from diffusers.utils import torch_device
2828

@@ -222,6 +222,98 @@ def test_downsample_with_conv_out_dim(self):
222222
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
223223

224224

225+
class ResnetBlock2DTests(unittest.TestCase):
    """Unit tests for ``diffusers.models.resnet.ResnetBlock2D``.

    Each test seeds the global RNG for reproducibility, builds a block on
    ``torch_device``, runs one forward pass under ``torch.no_grad()``,
    checks the output shape, and compares a 3x3 corner slice of the last
    output channel against hard-coded expected values (atol=1e-3).
    """

    def test_resnet_default(self):
        # Default block: channel count and spatial size are preserved.
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 64, 64)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    # NOTE(review): renamed from ``test_restnet_with_use_in_shortcut``
    # ("restnet" typo) for consistency with the other test names.
    def test_resnet_with_use_in_shortcut(self):
        # use_in_shortcut=True forces a 1x1 conv on the residual path even
        # though in/out channels match; shape is unchanged.
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, use_in_shortcut=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 64, 64)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_up(self):
        # up=True doubles the spatial resolution (64 -> 128).
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 128, 128)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_down(self):
        # down=True halves the spatial resolution (64 -> 32).
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    # NOTE(review): renamed from ``test_restnet_with_kernel_fir`` ("restnet"
    # typo) for consistency with the other test names.
    def test_resnet_with_kernel_fir(self):
        # Downsampling with the "fir" resampling kernel; same output shape
        # as plain down=True but different values.
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="fir", down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    # NOTE(review): renamed from ``test_restnet_with_kernel_sde_vp``
    # ("restnet" typo) for consistency with the other test names.
    def test_resnet_with_kernel_sde_vp(self):
        # "sde_vp" kernel with down=True; expected values match the plain
        # down=True case — presumably sde_vp downsampling reduces to average
        # pooling here, TODO confirm against ResnetBlock2D implementation.
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="sde_vp", down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
315+
316+
225317
class AttentionBlockTests(unittest.TestCase):
226318
@unittest.skipIf(
227319
torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"

0 commit comments

Comments
 (0)