11import logging
2- import math
32
43import torch .nn as nn
54import torch .utils .checkpoint as cp
5+
6+ from mmcv .cnn import constant_init , kaiming_init
67from mmcv .runner import load_checkpoint
78
89
@@ -27,7 +28,8 @@ def __init__(self,
2728 stride = 1 ,
2829 dilation = 1 ,
2930 downsample = None ,
30- style = 'pytorch' ):
31+ style = 'pytorch' ,
32+ with_cp = False ):
3133 super (BasicBlock , self ).__init__ ()
3234 self .conv1 = conv3x3 (inplanes , planes , stride , dilation )
3335 self .bn1 = nn .BatchNorm2d (planes )
@@ -37,6 +39,7 @@ def __init__(self,
3739 self .downsample = downsample
3840 self .stride = stride
3941 self .dilation = dilation
42+ assert not with_cp
4043
4144 def forward (self , x ):
4245 residual = x
@@ -69,7 +72,6 @@ def __init__(self,
6972 style = 'pytorch' ,
7073 with_cp = False ):
7174 """Bottleneck block.
72-
7375 If style is "pytorch", the stride-two layer is the 3x3 conv layer,
7476 if it is "caffe", the stride-two layer is the first 1x1 conv layer.
7577 """
@@ -174,64 +176,73 @@ def make_res_layer(block,
174176 return nn .Sequential (* layers )
175177
176178
177- class ResHead (nn .Module ):
178-
179- def __init__ (self ,
180- block ,
181- num_blocks ,
182- stride = 2 ,
183- dilation = 1 ,
184- style = 'pytorch' ):
185- self .layer4 = make_res_layer (
186- block ,
187- 1024 ,
188- 512 ,
189- num_blocks ,
190- stride = stride ,
191- dilation = dilation ,
192- style = style )
193-
194- def forward (self , x ):
195- return self .layer4 (x )
179+ class ResNet (nn .Module ):
180+ """ResNet backbone.
196181
182+ Args:
183+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
184+ num_stages (int): Resnet stages, normally 4.
185+ strides (Sequence[int]): Strides of the first block of each stage.
186+ dilations (Sequence[int]): Dilation of each stage.
187+ out_indices (Sequence[int]): Output from which stages.
188+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
189+ layer is the 3x3 conv layer, otherwise the stride-two layer is
190+ the first 1x1 conv layer.
191+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
192+ not freezing any parameters.
193+ bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
194+ running stats (mean and var).
195+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
196+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
197+ memory while slowing down the training speed.
198+ """
197199
198- class ResNet (nn .Module ):
200+ arch_settings = {
201+ 18 : (BasicBlock , (2 , 2 , 2 , 2 )),
202+ 34 : (BasicBlock , (3 , 4 , 6 , 3 )),
203+ 50 : (Bottleneck , (3 , 4 , 6 , 3 )),
204+ 101 : (Bottleneck , (3 , 4 , 23 , 3 )),
205+ 152 : (Bottleneck , (3 , 8 , 36 , 3 ))
206+ }
199207
200208 def __init__ (self ,
201- block ,
202- layers ,
209+ depth ,
210+ num_stages = 4 ,
203211 strides = (1 , 2 , 2 , 2 ),
204212 dilations = (1 , 1 , 1 , 1 ),
205213 out_indices = (0 , 1 , 2 , 3 ),
206- frozen_stages = - 1 ,
207214 style = 'pytorch' ,
208- sync_bn = False ,
209- with_cp = False ,
210- strict_frozen = False ):
215+ frozen_stages = - 1 ,
216+ bn_eval = True ,
217+ bn_frozen = False ,
218+ with_cp = False ):
211219 super (ResNet , self ).__init__ ()
212- if not len (layers ) == len (strides ) == len (dilations ):
213- raise ValueError (
214- 'The number of layers, strides and dilations must be equal, '
215- 'but found have {} layers, {} strides and {} dilations' .format (
216- len (layers ), len (strides ), len (dilations )))
217- assert max (out_indices ) < len (layers )
220+ if depth not in self .arch_settings :
221+ raise KeyError ('invalid depth {} for resnet' .format (depth ))
222+ assert num_stages >= 1 and num_stages <= 4
223+ block , stage_blocks = self .arch_settings [depth ]
224+ stage_blocks = stage_blocks [:num_stages ]
225+ assert len (strides ) == len (dilations ) == num_stages
226+ assert max (out_indices ) < num_stages
227+
218228 self .out_indices = out_indices
219- self .frozen_stages = frozen_stages
220229 self .style = style
221- self .sync_bn = sync_bn
230+ self .frozen_stages = frozen_stages
231+ self .bn_eval = bn_eval
232+ self .bn_frozen = bn_frozen
233+ self .with_cp = with_cp
234+
222235 self .inplanes = 64
223236 self .conv1 = nn .Conv2d (
224237 3 , 64 , kernel_size = 7 , stride = 2 , padding = 3 , bias = False )
225238 self .bn1 = nn .BatchNorm2d (64 )
226239 self .relu = nn .ReLU (inplace = True )
227240 self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
228- self .res_layers = []
229- for i , num_blocks in enumerate (layers ):
230241
242+ self .res_layers = []
243+ for i , num_blocks in enumerate (stage_blocks ):
231244 stride = strides [i ]
232245 dilation = dilations [i ]
233-
234- layer_name = 'layer{}' .format (i + 1 )
235246 planes = 64 * 2 ** i
236247 res_layer = make_res_layer (
237248 block ,
@@ -243,12 +254,11 @@ def __init__(self,
243254 style = self .style ,
244255 with_cp = with_cp )
245256 self .inplanes = planes * block .expansion
257+ layer_name = 'layer{}' .format (i + 1 )
246258 self .add_module (layer_name , res_layer )
247259 self .res_layers .append (layer_name )
248- self .feat_dim = block .expansion * 64 * 2 ** (len (layers ) - 1 )
249- self .with_cp = with_cp
250260
251- self .strict_frozen = strict_frozen
261+ self .feat_dim = block . expansion * 64 * 2 ** ( len ( stage_blocks ) - 1 )
252262
253263 def init_weights (self , pretrained = None ):
254264 if isinstance (pretrained , str ):
@@ -257,11 +267,9 @@ def init_weights(self, pretrained=None):
257267 elif pretrained is None :
258268 for m in self .modules ():
259269 if isinstance (m , nn .Conv2d ):
260- n = m .kernel_size [0 ] * m .kernel_size [1 ] * m .out_channels
261- nn .init .normal_ (m .weight , 0 , math .sqrt (2. / n ))
270+ kaiming_init (m )
262271 elif isinstance (m , nn .BatchNorm2d ):
263- nn .init .constant_ (m .weight , 1 )
264- nn .init .constant_ (m .bias , 0 )
272+ constant_init (m , 1 )
265273 else :
266274 raise TypeError ('pretrained must be a str or None' )
267275
@@ -283,11 +291,11 @@ def forward(self, x):
283291
284292 def train (self , mode = True ):
285293 super (ResNet , self ).train (mode )
286- if not self .sync_bn :
294+ if self .bn_eval :
287295 for m in self .modules ():
288296 if isinstance (m , nn .BatchNorm2d ):
289297 m .eval ()
290- if self .strict_frozen :
298+ if self .bn_frozen :
291299 for params in m .parameters ():
292300 params .requires_grad = False
293301 if mode and self .frozen_stages >= 0 :
@@ -303,39 +311,3 @@ def train(self, mode=True):
303311 mod .eval ()
304312 for param in mod .parameters ():
305313 param .requires_grad = False
306-
307-
308- resnet_cfg = {
309- 18 : (BasicBlock , (2 , 2 , 2 , 2 )),
310- 34 : (BasicBlock , (3 , 4 , 6 , 3 )),
311- 50 : (Bottleneck , (3 , 4 , 6 , 3 )),
312- 101 : (Bottleneck , (3 , 4 , 23 , 3 )),
313- 152 : (Bottleneck , (3 , 8 , 36 , 3 ))
314- }
315-
316-
317- def resnet (depth ,
318- num_stages = 4 ,
319- strides = (1 , 2 , 2 , 2 ),
320- dilations = (1 , 1 , 1 , 1 ),
321- out_indices = (2 , ),
322- frozen_stages = - 1 ,
323- style = 'pytorch' ,
324- sync_bn = False ,
325- with_cp = False ,
326- strict_frozen = False ):
327- """Constructs a ResNet model.
328-
329- Args:
330- depth (int): depth of resnet, from {18, 34, 50, 101, 152}
331- num_stages (int): num of resnet stages, normally 4
332- strides (list): strides of the first block of each stage
333- dilations (list): dilation of each stage
334- out_indices (list): output from which stages
335- """
336- if depth not in resnet_cfg :
337- raise KeyError ('invalid depth {} for resnet' .format (depth ))
338- block , layers = resnet_cfg [depth ]
339- model = ResNet (block , layers [:num_stages ], strides , dilations , out_indices ,
340- frozen_stages , style , sync_bn , with_cp , strict_frozen )
341- return model
0 commit comments