|
1 | 1 | import torch |
2 | 2 | import unittest |
3 | 3 |
|
| 4 | +import torch.nn as nn |
| 5 | +from torch import Tensor |
| 6 | + |
4 | 7 | from src.losses import StyleLoss |
5 | 8 |
|
6 | 9 | class StyleLossTest(unittest.TestCase): |
7 | 10 | def setUp(self) -> None: |
8 | 11 |
|
9 | 12 | self.criterion = StyleLoss( |
10 | | - enc_depth=(3, 10, 17, 30), |
| 13 | + enc_depth=(5, 12, 19, 32), |
11 | 14 | backbone='vgg19_bn', |
12 | 15 | ) |
13 | 16 |
|
14 | 17 | # Create a dummy input of correct shape |
15 | 18 | self.input_shape = (2, 3, 256, 256) |
16 | 19 |
|
17 | | - self.orig_img = torch.randn(*self.input_shape) |
18 | | - self.targ_sty = torch.randn(*self.input_shape) |
19 | | - self.pred_img = torch.randn(*self.input_shape) |
| 20 | + self.orig_img = torch.randn(*self.input_shape, requires_grad=True) |
| 21 | + self.targ_sty = torch.randn(*self.input_shape, requires_grad=True) |
| 22 | + self.pred_img = torch.randn(*self.input_shape, requires_grad=True) |
| 23 | + |
| 24 | + def test_default_is_relu(self): |
| 25 | + layers = self.criterion.backbone.layers |
20 | 26 |
|
| 27 | + self.assertTrue(all([isinstance(l, nn.ReLU) for l in layers])) |
| 28 | + |
21 | 29 | def test_forward(self): |
22 | 30 | loss = self.criterion(self.orig_img, self.targ_sty, self.pred_img) |
23 | 31 |
|
24 | 32 | # Check the output shape |
25 | 33 | self.assertTrue(loss > 0) |
| 34 | + |
| 35 | + def test_backward(self): |
| 36 | + loss : Tensor = self.criterion(self.orig_img, self.targ_sty, self.pred_img) |
| 37 | + |
| 38 | + loss.backward() |
| 39 | + |
| 40 | + # Check the output shape |
| 41 | + self.assertTrue(loss > 0) |
| 42 | + |
0 commit comments