Skip to content

Commit 773b768

Browse files
committed
Move init_module in Network class to share the initialization.
1 parent 529f97e commit 773b768

File tree

3 files changed

+21
-26
lines changed

3 files changed

+21
-26
lines changed

lib/nets/network.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,21 @@ def create_architecture(self, num_classes, tag=None,
310310
# Initialize layers
311311
self._init_modules()
312312

313+
def _init_modules(self):
314+
self._init_head_tail()
315+
316+
# rpn
317+
self.rpn_net = nn.Conv2d(self._net_conv_channels, cfg.RPN_CHANNELS, [3, 3], padding=1)
318+
319+
self.rpn_cls_score_net = nn.Conv2d(cfg.RPN_CHANNELS, self._num_anchors * 2, [1, 1])
320+
321+
self.rpn_bbox_pred_net = nn.Conv2d(cfg.RPN_CHANNELS, self._num_anchors * 4, [1, 1])
322+
323+
self.cls_score_net = nn.Linear(self._fc7_channels, self._num_classes)
324+
self.bbox_pred_net = nn.Linear(self._fc7_channels, self._num_classes * 4)
325+
326+
self.init_weights()
327+
313328
def _run_summary_op(self, val=False):
314329
"""
315330
Run the summary operator: feed the placeholders with corresponding newtork outputs(activations)

lib/nets/resnet_v1.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ def __init__(self, num_layers=50):
211211
self._feat_stride = [16, ]
212212
self._feat_compress = [1. / float(self._feat_stride[0]), ]
213213
self._num_layers = num_layers
214+
self._net_conv_channels = 1024
215+
self._fc7_channels = 2048
214216

215217
def _crop_pool_layer(self, bottom, rois):
216218
return Network._crop_pool_layer(self, bottom, rois, cfg.RESNET.MAX_POOL)
@@ -225,7 +227,7 @@ def _head_to_tail(self, pool5):
225227
fc7 = self.resnet.layer4(pool5).mean(3).mean(2) # average pooling after layer4
226228
return fc7
227229

228-
def _init_modules(self):
230+
def _init_head_tail(self):
229231
# choose different blocks for different number of layers
230232
if self._num_layers == 50:
231233
self.resnet = resnet50()
@@ -262,18 +264,6 @@ def set_bn_fix(m):
262264
self._layers['head'] = nn.Sequential(self.resnet.conv1, self.resnet.bn1,self.resnet.relu,
263265
self.resnet.maxpool,self.resnet.layer1,self.resnet.layer2,self.resnet.layer3)
264266

265-
# rpn
266-
self.rpn_net = nn.Conv2d(1024, 512, [3, 3], padding=1)
267-
268-
self.rpn_cls_score_net = nn.Conv2d(512, self._num_anchors * 2, [1, 1])
269-
270-
self.rpn_bbox_pred_net = nn.Conv2d(512, self._num_anchors * 4, [1, 1])
271-
272-
self.cls_score_net = nn.Linear(2048, self._num_classes)
273-
self.bbox_pred_net = nn.Linear(2048, self._num_classes * 4)
274-
275-
self.init_weights()
276-
277267
def train(self, mode=True):
278268
# Override train so that the training mode is set as we want
279269
nn.Module.train(self, mode)

lib/nets/vgg16.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ def __init__(self):
2222
Network.__init__(self)
2323
self._feat_stride = [16, ]
2424
self._feat_compress = [1. / float(self._feat_stride[0]), ]
25+
self._net_conv_channels = 512
26+
self._fc7_channels = 4096
2527

26-
def _init_modules(self):
28+
def _init_head_tail(self):
2729
self.vgg = models.vgg16()
2830
# Remove fc8
2931
self.vgg.classifier = nn.Sequential(*list(self.vgg.classifier._modules.values())[:-1])
@@ -35,18 +37,6 @@ def _init_modules(self):
3537
# not using the last maxpool layer
3638
self._layers['head'] = nn.Sequential(*list(self.vgg.features._modules.values())[:-1])
3739

38-
# rpn
39-
self.rpn_net = nn.Conv2d(512, 512, [3, 3], padding=1)
40-
41-
self.rpn_cls_score_net = nn.Conv2d(512, self._num_anchors * 2, [1, 1])
42-
43-
self.rpn_bbox_pred_net = nn.Conv2d(512, self._num_anchors * 4, [1, 1])
44-
45-
self.cls_score_net = nn.Linear(4096, self._num_classes)
46-
self.bbox_pred_net = nn.Linear(4096, self._num_classes * 4)
47-
48-
self.init_weights()
49-
5040
def _image_to_head(self):
5141
net_conv = self._layers['head'](self._image)
5242
self._act_summaries['conv'] = net_conv

0 commit comments

Comments
 (0)