Skip to content

Commit bf770f5

Browse files
author
Ruotian Luo
committed
Make changes for mobilenet. (Training is still not working)
1 parent 7fd5263 commit bf770f5

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Additional features not mentioned in the [report](https://arxiv.org/pdf/1702.021
8888
### Setup data
8989
Please follow the instructions of py-faster-rcnn [here](https://github.com/rbgirshick/py-faster-rcnn#beyond-the-demo-installation-for-training-and-testing-models) to setup VOC and COCO datasets (Part of COCO is done). The steps involve downloading data and optionally creating soft links in the ``data`` folder. Since faster RCNN does not rely on pre-computed proposals, it is safe to ignore the steps that setup proposals.
9090

91-
If you find it useful, the ``data/cache`` folder created on Xinlei's side is also shared [here](http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/cache.tgz).
91+
If you find it useful, the ``data/cache`` folder created on Xinlei's side is also shared [here](https://drive.google.com/drive/folders/0B1_fAEgxdnvJSmF3YUlZcHFqWTQ).
9292

9393
### Demo and Test with pre-trained models
9494
1. Download pre-trained model (only google drive works)
@@ -173,6 +173,15 @@ This script will create a `.pth` file with the same name in the same folder as t
173173
cd ../..
174174
```
175175

176+
For Mobilenet V1, you can set up like:
177+
```Shell
178+
mkdir -p data/imagenet_weights
179+
cd data/imagenet_weights
180+
# download from my gdrive (https://drive.google.com/open?id=0B7fNdx_jAqhtZGJvZlpVeDhUN1k)
181+
mv mobilenet_v1_1.0_224.pth.pth mobile.pth
182+
cd ../..
183+
```
184+
176185
2. Train (and test, evaluation)
177186
```Shell
178187
./experiments/scripts/train_faster_rcnn.sh [GPU_ID] [DATASET] [NET]

lib/model/train_val.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def construct_graph(self):
128128
if 'bias' in key:
129129
params += [{'params':[value],'lr':lr*(cfg.TRAIN.DOUBLE_BIAS + 1), 'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
130130
else:
131-
params += [{'params':[value],'lr':lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
131+
params += [{'params':[value],'lr':lr, 'weight_decay': getattr(value, 'weight_decay', cfg.TRAIN.WEIGHT_DECAY)}]
132132
self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
133133
# Write the train and validation information to tensorboard
134134
self.writer = tb.writer.FileWriter(self.tbdir)

lib/nets/mobilenet_v1.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -229,22 +229,18 @@ def set_bn_fix(m):
229229
self.mobilenet.apply(set_bn_fix)
230230

231231
# Add weight decay
232-
def l2_regularizer(m, wd):
232+
def l2_regularizer(m, wd, regu_depth):
233233
if m.__class__.__name__.find('Conv') != -1:
234-
m.weight.weight_decay = cfg.MOBILENET.WEIGHT_DECAY
235-
if cfg.MOBILENET.REGU_DEPTH:
236-
self.mobilenet.apply(lambda x: l2_regularizer(x, cfg.MOBILENET.WEIGHT_DECAY))
237-
else:
238-
self.mobilenet.apply(lambda x: l2_regularizer(x, 0))
239-
# always set the first conv layer
240-
list(self.mobilenet.children())[0].apply(lambda x: l2_regularizer(x, cfg.MOBILENET.WEIGHT_DECAY))
234+
if regu_depth or m.groups == 1:
235+
m.weight.weight_decay = wd
236+
else:
237+
m.weight.weight_decay = 0
238+
self.mobilenet.apply(lambda x: l2_regularizer(x, cfg.MOBILENET.WEIGHT_DECAY, cfg.MOBILENET.REGU_DEPTH))
241239

242240
# Build mobilenet.
243241
self._layers['head'] = nn.Sequential(*list(self.mobilenet.children())[:12])
244242
self._layers['tail'] = nn.Sequential(*list(self.mobilenet.children())[12:])
245243

246244
def load_pretrained_cnn(self, state_dict):
247-
# TODO
248245
print('Warning: No available pretrained model yet')
249-
return
250-
self.mobilenet.load_state_dict({k: state_dict[k] for k in list(self.resnet.state_dict())})
246+
self.mobilenet.load_state_dict({k: state_dict['features.'+k] for k in list(self.mobilenet.state_dict())})

0 commit comments

Comments
 (0)