11import torch
22import torch .nn as nn
3+ import torch .nn .init as init
34from numpy .random import normal
45from numpy .linalg import svd
56from math import sqrt
67
78
8- def _get_orthogonal_init_weights (weights ):
9- fan_out = weights .size (0 )
10- fan_in = weights .size (1 ) * weights .size (2 ) * weights .size (3 )
11-
12- u , _ , v = svd (normal (0.0 , 1.0 , (fan_out , fan_in )), full_matrices = False )
13-
14- if u .shape == (fan_out , fan_in ):
15- return torch .Tensor (u .reshape (weights .size ()))
16- else :
17- return torch .Tensor (v .reshape (weights .size ()))
18-
19-
209class Net (nn .Module ):
2110 def __init__ (self , upscale_factor ):
2211 super (Net , self ).__init__ ()
@@ -38,7 +27,7 @@ def forward(self, x):
3827 return x
3928
4029 def _initialize_weights (self ):
41- self . conv1 . weight . data . copy_ ( _get_orthogonal_init_weights ( self .conv1 .weight ) * sqrt ( 2 ))
42- self . conv2 . weight . data . copy_ ( _get_orthogonal_init_weights ( self .conv2 .weight ) * sqrt ( 2 ))
43- self . conv3 . weight . data . copy_ ( _get_orthogonal_init_weights ( self .conv3 .weight ) * sqrt ( 2 ))
44- self . conv4 . weight . data . copy_ ( _get_orthogonal_init_weights ( self .conv4 .weight ) )
30+ init . orthogonal ( self .conv1 .weight , init . gain ( 'relu' ))
31+ init . orthogonal ( self .conv2 .weight , init . gain ( 'relu' ))
32+ init . orthogonal ( self .conv3 .weight , init . gain ( 'relu' ))
33+ init . orthogonal ( self .conv4 .weight )
0 commit comments