Skip to content

Commit 609a3ea

Browse files
committed
Fixed bug in loss, added backward test
1 parent 3dc7439 commit 609a3ea

File tree

4 files changed

+43
-9
lines changed

4 files changed

+43
-9
lines changed

src/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class StyleLoss(nn.Module):
5858

5959
def __init__(
6060
self,
61-
enc_depth : List[int] = (3, 10, 17, 30),
61+
enc_depth : List[int] = (5, 12, 19, 32),
6262
backbone : str = 'vgg19_bn',
6363
content_weight : float = .8,
6464
) -> None:

src/recorder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,11 @@ def __init__(
2323
self.feats = {l : [] for l in names}
2424

2525
def __call__(self, module : nn.Module, inp : Tensor, out : Tensor) -> None:
26-
# Detach layer output from PyTorch graph
27-
data = out.detach()
28-
2926
# Get the module name
3027
layer = module.name
3128

32-
self.feats[layer].append(data)
29+
# NOTE: It's important NOT to detach output from grad graph
30+
self.feats[layer].append(out)
3331

3432
def clean(
3533
self,

test/test_hflow.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import torch
22
import unittest
33

4+
from torch import Tensor
5+
46
from src.hflow import HierarchyFlow
7+
from src.losses import StyleLoss
58

69
class HFlowTest(unittest.TestCase):
710
def setUp(self) -> None:
@@ -19,6 +22,11 @@ def setUp(self) -> None:
1922
]
2023
)
2124

25+
self.criterion = StyleLoss(
26+
enc_depth=(5, 12, 19, 32),
27+
backbone='vgg19_bn',
28+
)
29+
2230
# Create a dummy input of correct shape
2331
self.input_shape = (2, 3, 256, 256)
2432

@@ -30,3 +38,14 @@ def test_forward(self):
3038

3139
# Check the output shape
3240
self.assertEqual(output.shape, self.input_shape)
41+
42+
def test_loss_backward(self):
43+
pred_img = self.flow(self.dummy_subj, self.dummy_style)
44+
45+
loss : Tensor = self.criterion(self.dummy_subj, self.dummy_style, pred_img)
46+
47+
loss.backward()
48+
49+
# Check the output shape
50+
self.assertTrue(loss > 0)
51+

test/test_loss.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,42 @@
11
import torch
22
import unittest
33

4+
import torch.nn as nn
5+
from torch import Tensor
6+
47
from src.losses import StyleLoss
58

69
class StyleLossTest(unittest.TestCase):
710
def setUp(self) -> None:
811

912
self.criterion = StyleLoss(
10-
enc_depth=(3, 10, 17, 30),
13+
enc_depth=(5, 12, 19, 32),
1114
backbone='vgg19_bn',
1215
)
1316

1417
# Create a dummy input of correct shape
1518
self.input_shape = (2, 3, 256, 256)
1619

17-
self.orig_img = torch.randn(*self.input_shape)
18-
self.targ_sty = torch.randn(*self.input_shape)
19-
self.pred_img = torch.randn(*self.input_shape)
20+
self.orig_img = torch.randn(*self.input_shape, requires_grad=True)
21+
self.targ_sty = torch.randn(*self.input_shape, requires_grad=True)
22+
self.pred_img = torch.randn(*self.input_shape, requires_grad=True)
23+
24+
def test_default_is_relu(self):
25+
layers = self.criterion.backbone.layers
2026

27+
self.assertTrue(all([isinstance(l, nn.ReLU) for l in layers]))
28+
2129
def test_forward(self):
2230
loss = self.criterion(self.orig_img, self.targ_sty, self.pred_img)
2331

2432
# Check the output shape
2533
self.assertTrue(loss > 0)
34+
35+
def test_backward(self):
36+
loss : Tensor = self.criterion(self.orig_img, self.targ_sty, self.pred_img)
37+
38+
loss.backward()
39+
40+
# Check the output shape
41+
self.assertTrue(loss > 0)
42+

0 commit comments

Comments
 (0)