Skip to content

Commit d165984

Browse files
Kaixhinsoumith
authored andcommitted
Use nn.init from core to init SR example (#189)
Makes it clearer what is going on in the init as well. Closes #161
1 parent 6b17f79 commit d165984

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

super_resolution/model.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
11
import torch
22
import torch.nn as nn
3+
import torch.nn.init as init
34
from numpy.random import normal
45
from numpy.linalg import svd
56
from math import sqrt
67

78

8-
def _get_orthogonal_init_weights(weights):
9-
fan_out = weights.size(0)
10-
fan_in = weights.size(1) * weights.size(2) * weights.size(3)
11-
12-
u, _, v = svd(normal(0.0, 1.0, (fan_out, fan_in)), full_matrices=False)
13-
14-
if u.shape == (fan_out, fan_in):
15-
return torch.Tensor(u.reshape(weights.size()))
16-
else:
17-
return torch.Tensor(v.reshape(weights.size()))
18-
19-
209
class Net(nn.Module):
2110
def __init__(self, upscale_factor):
2211
super(Net, self).__init__()
@@ -38,7 +27,7 @@ def forward(self, x):
3827
return x
3928

4029
def _initialize_weights(self):
41-
self.conv1.weight.data.copy_(_get_orthogonal_init_weights(self.conv1.weight) * sqrt(2))
42-
self.conv2.weight.data.copy_(_get_orthogonal_init_weights(self.conv2.weight) * sqrt(2))
43-
self.conv3.weight.data.copy_(_get_orthogonal_init_weights(self.conv3.weight) * sqrt(2))
44-
self.conv4.weight.data.copy_(_get_orthogonal_init_weights(self.conv4.weight))
30+
init.orthogonal(self.conv1.weight, init.gain('relu'))
31+
init.orthogonal(self.conv2.weight, init.gain('relu'))
32+
init.orthogonal(self.conv3.weight, init.gain('relu'))
33+
init.orthogonal(self.conv4.weight)

0 commit comments

Comments
 (0)