 import warnings
 
 import torch.nn as nn
+from mmcv.cnn import kaiming_init, constant_init
 
 from .norm import build_norm_layer
 
@@ -51,15 +52,8 @@ def __init__(self,
         self.groups = self.conv.groups
 
         if self.with_norm:
-            # self.norm_type, self.norm_params = parse_norm(normalize)
-            # assert self.norm_type in [None, 'BN', 'SyncBN', 'GN', 'SN']
-            # self.Norm2d = norm_cfg[self.norm_type]
-            if self.activate_last:
-                self.norm = build_norm_layer(normalize, out_channels)
-                # self.norm = self.Norm2d(out_channels, **self.norm_params)
-            else:
-                self.norm = build_norm_layer(normalize, in_channels)
-                # self.norm = self.Norm2d(in_channels, **self.norm_params)
+            norm_channels = out_channels if self.activate_last else in_channels
+            self.norm = build_norm_layer(normalize, norm_channels)
 
         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -71,13 +65,9 @@ def __init__(self,
 
     def init_weights(self):
         nonlinearity = 'relu' if self.activation is None else self.activation
-        nn.init.kaiming_normal_(
-            self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)
-        if self.with_bias:
-            nn.init.constant_(self.conv.bias, 0)
+        kaiming_init(self.conv, nonlinearity=nonlinearity)
         if self.with_norm:
-            nn.init.constant_(self.norm.weight, 1)
-            nn.init.constant_(self.norm.bias, 0)
+            constant_init(self.norm, 1, bias=0)
 
     def forward(self, x, activate=True, norm=True):
         if self.activate_last:
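For context, the two mmcv helpers introduced here wrap the same torch.nn.init calls that the deleted lines spelled out by hand. A minimal standalone sketch of that equivalence, assuming mmcv's documented defaults (mode='fan_out', normal distribution, zero bias) rather than anything stated in this diff:

import torch.nn as nn
from mmcv.cnn import kaiming_init, constant_init

conv = nn.Conv2d(3, 16, kernel_size=3)
bn = nn.BatchNorm2d(16)

# Helper calls as used in the new init_weights()
kaiming_init(conv, nonlinearity='relu')  # kaiming_normal_ on conv.weight, zeroes conv.bias
constant_init(bn, 1, bias=0)             # fills bn.weight with 1 and bn.bias with 0

# Roughly what the removed code did explicitly
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
if conv.bias is not None:
    nn.init.constant_(conv.bias, 0)
nn.init.constant_(bn.weight, 1)
nn.init.constant_(bn.bias, 0)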