|
22 | 22 |
|
23 | 23 | from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock |
24 | 24 | from diffusers.models.embeddings import get_timestep_embedding |
25 | | -from diffusers.models.resnet import Downsample2D, Upsample2D |
| 25 | +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D |
26 | 26 | from diffusers.models.transformer_2d import Transformer2DModel |
27 | 27 | from diffusers.utils import torch_device |
28 | 28 |
|
@@ -222,6 +222,98 @@ def test_downsample_with_conv_out_dim(self): |
222 | 222 | assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) |
223 | 223 |
|
224 | 224 |
|
class ResnetBlock2DTests(unittest.TestCase):
    """Unit tests for ``ResnetBlock2D``.

    Covers the default configuration, the forced input shortcut, the
    up/down resampling paths, and the "fir" / "sde_vp" resampling kernels.
    Each test runs a single seeded forward pass and pins the output shape
    plus a 3x3 corner slice of the last channel against golden values.
    """

    def _check_resnet_block(self, expected_shape, expected_values, **block_kwargs):
        """Build a seeded ResnetBlock2D, run one no-grad forward pass, and
        assert the output shape and the flattened ``[0, -1, -3:, -3:]`` slice.

        Args:
            expected_shape: expected ``output_tensor.shape`` as a tuple.
            expected_values: nine golden floats for the corner slice.
            **block_kwargs: extra kwargs forwarded to ``ResnetBlock2D``
                (e.g. ``up=True``, ``down=True``, ``kernel="fir"``).
        """
        # Seed before *all* random ops (inputs and module init) so every
        # test reproduces the exact tensors the golden slices were taken from.
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, **block_kwargs).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == expected_shape
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(expected_values, device=torch_device)
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_default(self):
        # No resampling: spatial size is preserved.
        self._check_resnet_block(
            (1, 32, 64, 64),
            [-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746],
        )

    def test_resnet_with_use_in_shortcut(self):
        # Forcing the 1x1 conv shortcut changes the residual path (and output).
        self._check_resnet_block(
            (1, 32, 64, 64),
            [0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044],
            use_in_shortcut=True,
        )

    def test_resnet_up(self):
        # up=True doubles the spatial resolution: 64 -> 128.
        self._check_resnet_block(
            (1, 32, 128, 128),
            [1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182],
            up=True,
        )

    def test_resnet_down(self):
        # down=True halves the spatial resolution: 64 -> 32.
        self._check_resnet_block(
            (1, 32, 32, 32),
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436],
            down=True,
        )

    def test_resnet_with_kernel_fir(self):
        # FIR kernel downsampling produces a different result than the default kernel.
        self._check_resnet_block(
            (1, 32, 32, 32),
            [-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136],
            kernel="fir",
            down=True,
        )

    def test_resnet_with_kernel_sde_vp(self):
        # NOTE: golden slice matches test_resnet_down — the "sde_vp" kernel
        # uses the same average-pool downsampling as the default path.
        self._check_resnet_block(
            (1, 32, 32, 32),
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436],
            kernel="sde_vp",
            down=True,
        )
| 315 | + |
| 316 | + |
225 | 317 | class AttentionBlockTests(unittest.TestCase): |
226 | 318 | @unittest.skipIf( |
227 | 319 | torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039" |
|
0 commit comments