Skip to content

Commit 0e747be

Browse files
committed
update resnet backbone
1 parent e8397e4 commit 0e747be

File tree

5 files changed

+63
-91
lines changed

5 files changed

+63
-91
lines changed

configs/faster_rcnn_r50_fpn_1x.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
type='FasterRCNN',
44
pretrained='modelzoo://resnet50',
55
backbone=dict(
6-
type='resnet',
6+
type='ResNet',
77
depth=50,
88
num_stages=4,
99
out_indices=(0, 1, 2, 3),

configs/mask_rcnn_r50_fpn_1x.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
type='MaskRCNN',
44
pretrained='modelzoo://resnet50',
55
backbone=dict(
6-
type='resnet',
6+
type='ResNet',
77
depth=50,
88
num_stages=4,
99
out_indices=(0, 1, 2, 3),

configs/rpn_r50_fpn_1x.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
type='RPN',
44
pretrained='modelzoo://resnet50',
55
backbone=dict(
6-
type='resnet',
6+
type='ResNet',
77
depth=50,
88
num_stages=4,
99
out_indices=(0, 1, 2, 3),

mmdet/models/backbones/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
from .resnet import resnet
1+
from .resnet import ResNet
22

3-
__all__ = ['resnet']
3+
__all__ = ['ResNet']

mmdet/models/backbones/resnet.py

Lines changed: 58 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
import logging
2-
import math
32

43
import torch.nn as nn
54
import torch.utils.checkpoint as cp
5+
6+
from mmcv.cnn import constant_init, kaiming_init
67
from mmcv.runner import load_checkpoint
78

89

@@ -27,7 +28,8 @@ def __init__(self,
2728
stride=1,
2829
dilation=1,
2930
downsample=None,
30-
style='pytorch'):
31+
style='pytorch',
32+
with_cp=False):
3133
super(BasicBlock, self).__init__()
3234
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
3335
self.bn1 = nn.BatchNorm2d(planes)
@@ -37,6 +39,7 @@ def __init__(self,
3739
self.downsample = downsample
3840
self.stride = stride
3941
self.dilation = dilation
42+
assert not with_cp
4043

4144
def forward(self, x):
4245
residual = x
@@ -69,7 +72,6 @@ def __init__(self,
6972
style='pytorch',
7073
with_cp=False):
7174
"""Bottleneck block.
72-
7375
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
7476
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
7577
"""
@@ -174,64 +176,73 @@ def make_res_layer(block,
174176
return nn.Sequential(*layers)
175177

176178

177-
class ResHead(nn.Module):
178-
179-
def __init__(self,
180-
block,
181-
num_blocks,
182-
stride=2,
183-
dilation=1,
184-
style='pytorch'):
185-
self.layer4 = make_res_layer(
186-
block,
187-
1024,
188-
512,
189-
num_blocks,
190-
stride=stride,
191-
dilation=dilation,
192-
style=style)
193-
194-
def forward(self, x):
195-
return self.layer4(x)
179+
class ResNet(nn.Module):
180+
"""ResNet backbone.
196181
182+
Args:
183+
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
184+
num_stages (int): Resnet stages, normally 4.
185+
strides (Sequence[int]): Strides of the first block of each stage.
186+
dilations (Sequence[int]): Dilation of each stage.
187+
out_indices (Sequence[int]): Output from which stages.
188+
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
189+
layer is the 3x3 conv layer, otherwise the stride-two layer is
190+
the first 1x1 conv layer.
191+
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
192+
not freezing any parameters.
193+
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
194+
running stats (mean and var).
195+
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
196+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
197+
memory while slowing down the training speed.
198+
"""
197199

198-
class ResNet(nn.Module):
200+
arch_settings = {
201+
18: (BasicBlock, (2, 2, 2, 2)),
202+
34: (BasicBlock, (3, 4, 6, 3)),
203+
50: (Bottleneck, (3, 4, 6, 3)),
204+
101: (Bottleneck, (3, 4, 23, 3)),
205+
152: (Bottleneck, (3, 8, 36, 3))
206+
}
199207

200208
def __init__(self,
201-
block,
202-
layers,
209+
depth,
210+
num_stages=4,
203211
strides=(1, 2, 2, 2),
204212
dilations=(1, 1, 1, 1),
205213
out_indices=(0, 1, 2, 3),
206-
frozen_stages=-1,
207214
style='pytorch',
208-
sync_bn=False,
209-
with_cp=False,
210-
strict_frozen=False):
215+
frozen_stages=-1,
216+
bn_eval=True,
217+
bn_frozen=False,
218+
with_cp=False):
211219
super(ResNet, self).__init__()
212-
if not len(layers) == len(strides) == len(dilations):
213-
raise ValueError(
214-
'The number of layers, strides and dilations must be equal, '
215-
'but found have {} layers, {} strides and {} dilations'.format(
216-
len(layers), len(strides), len(dilations)))
217-
assert max(out_indices) < len(layers)
220+
if depth not in self.arch_settings:
221+
raise KeyError('invalid depth {} for resnet'.format(depth))
222+
assert num_stages >= 1 and num_stages <= 4
223+
block, stage_blocks = self.arch_settings[depth]
224+
stage_blocks = stage_blocks[:num_stages]
225+
assert len(strides) == len(dilations) == num_stages
226+
assert max(out_indices) < num_stages
227+
218228
self.out_indices = out_indices
219-
self.frozen_stages = frozen_stages
220229
self.style = style
221-
self.sync_bn = sync_bn
230+
self.frozen_stages = frozen_stages
231+
self.bn_eval = bn_eval
232+
self.bn_frozen = bn_frozen
233+
self.with_cp = with_cp
234+
222235
self.inplanes = 64
223236
self.conv1 = nn.Conv2d(
224237
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
225238
self.bn1 = nn.BatchNorm2d(64)
226239
self.relu = nn.ReLU(inplace=True)
227240
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
228-
self.res_layers = []
229-
for i, num_blocks in enumerate(layers):
230241

242+
self.res_layers = []
243+
for i, num_blocks in enumerate(stage_blocks):
231244
stride = strides[i]
232245
dilation = dilations[i]
233-
234-
layer_name = 'layer{}'.format(i + 1)
235246
planes = 64 * 2**i
236247
res_layer = make_res_layer(
237248
block,
@@ -243,12 +254,11 @@ def __init__(self,
243254
style=self.style,
244255
with_cp=with_cp)
245256
self.inplanes = planes * block.expansion
257+
layer_name = 'layer{}'.format(i + 1)
246258
self.add_module(layer_name, res_layer)
247259
self.res_layers.append(layer_name)
248-
self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
249-
self.with_cp = with_cp
250260

251-
self.strict_frozen = strict_frozen
261+
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
252262

253263
def init_weights(self, pretrained=None):
254264
if isinstance(pretrained, str):
@@ -257,11 +267,9 @@ def init_weights(self, pretrained=None):
257267
elif pretrained is None:
258268
for m in self.modules():
259269
if isinstance(m, nn.Conv2d):
260-
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
261-
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
270+
kaiming_init(m)
262271
elif isinstance(m, nn.BatchNorm2d):
263-
nn.init.constant_(m.weight, 1)
264-
nn.init.constant_(m.bias, 0)
272+
constant_init(m, 1)
265273
else:
266274
raise TypeError('pretrained must be a str or None')
267275

@@ -283,11 +291,11 @@ def forward(self, x):
283291

284292
def train(self, mode=True):
285293
super(ResNet, self).train(mode)
286-
if not self.sync_bn:
294+
if self.bn_eval:
287295
for m in self.modules():
288296
if isinstance(m, nn.BatchNorm2d):
289297
m.eval()
290-
if self.strict_frozen:
298+
if self.bn_frozen:
291299
for params in m.parameters():
292300
params.requires_grad = False
293301
if mode and self.frozen_stages >= 0:
@@ -303,39 +311,3 @@ def train(self, mode=True):
303311
mod.eval()
304312
for param in mod.parameters():
305313
param.requires_grad = False
306-
307-
308-
resnet_cfg = {
309-
18: (BasicBlock, (2, 2, 2, 2)),
310-
34: (BasicBlock, (3, 4, 6, 3)),
311-
50: (Bottleneck, (3, 4, 6, 3)),
312-
101: (Bottleneck, (3, 4, 23, 3)),
313-
152: (Bottleneck, (3, 8, 36, 3))
314-
}
315-
316-
317-
def resnet(depth,
318-
num_stages=4,
319-
strides=(1, 2, 2, 2),
320-
dilations=(1, 1, 1, 1),
321-
out_indices=(2, ),
322-
frozen_stages=-1,
323-
style='pytorch',
324-
sync_bn=False,
325-
with_cp=False,
326-
strict_frozen=False):
327-
"""Constructs a ResNet model.
328-
329-
Args:
330-
depth (int): depth of resnet, from {18, 34, 50, 101, 152}
331-
num_stages (int): num of resnet stages, normally 4
332-
strides (list): strides of the first block of each stage
333-
dilations (list): dilation of each stage
334-
out_indices (list): output from which stages
335-
"""
336-
if depth not in resnet_cfg:
337-
raise KeyError('invalid depth {} for resnet'.format(depth))
338-
block, layers = resnet_cfg[depth]
339-
model = ResNet(block, layers[:num_stages], strides, dilations, out_indices,
340-
frozen_stages, style, sync_bn, with_cp, strict_frozen)
341-
return model

0 commit comments

Comments
 (0)