Skip to content

Commit f916db4

Browse files
author
Ruotian Luo
committed
Fix the mobilenet bug. Support Mobilenet training now.
1 parent bf770f5 commit f916db4

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ With Resnet101 (last ``conv4``):
2424
- Train on COCO 2014 trainval35k and test on minival (800k/1190k), **35.1**(from scratch) **35.4**(converted) (**35.4** for tf-faster-rcnn).
2525

2626
More Results:
27-
- Train Mobilenet (1.0, 224) on COCO 2014 trainval35k and test on minival (900k/1190k), **21.9**(converted) (**21.8** for tf-faster-rcnn).
27+
- Train Mobilenet (1.0, 224) on COCO 2014 trainval35k and test on minival (900k/1190k), **21.4**(from scratch), **21.9**(converted) (**21.8** for tf-faster-rcnn).
2828
- Train Resnet50 on COCO 2014 trainval35k and test on minival (900k/1190k), **32.4**(converted) (**32.4** for tf-faster-rcnn).
2929
- Train Resnet152 on COCO 2014 trainval35k and test on minival (900k/1190k), **36.7**(converted) (**36.1** for tf-faster-rcnn).
3030

lib/nets/mobilenet_v1.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,22 @@ def l2_regularizer(m, wd, regu_depth):
241241
self._layers['head'] = nn.Sequential(*list(self.mobilenet.children())[:12])
242242
self._layers['tail'] = nn.Sequential(*list(self.mobilenet.children())[12:])
243243

244+
def train(self, mode=True):
245+
# Override train so that the training mode is set as we want
246+
nn.Module.train(self, mode)
247+
if mode:
248+
# Set fixed blocks to be in eval mode (not really doing anything)
249+
for m in list(self.mobilenet.children())[:cfg.MOBILENET.FIXED_LAYERS]:
250+
m.eval()
251+
252+
# Set batchnorm always in eval mode during training
253+
def set_bn_eval(m):
254+
classname = m.__class__.__name__
255+
if classname.find('BatchNorm') != -1:
256+
m.eval()
257+
258+
self.mobilenet.apply(set_bn_eval)
259+
244260
def load_pretrained_cnn(self, state_dict):
245261
print('Warning: No available pretrained model yet')
246262
self.mobilenet.load_state_dict({k: state_dict['features.'+k] for k in list(self.mobilenet.state_dict())})

0 commit comments

Comments
 (0)