Skip to content

Commit 7fd5263

Browse files
committed
Change API to pytorch 0.4; also, support CPU during demo and test time (it doesn't make sense to support CPU during training).
1 parent fa88df8 commit 7fd5263

File tree

11 files changed

+104
-140
lines changed

11 files changed

+104
-140
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Additional features not mentioned in the [report](https://arxiv.org/pdf/1702.021
5858
- **Support for visualization**. The current implementation will summarize ground truth boxes, statistics of losses, activations and variables during training, and dump it to a separate folder for tensorboard visualization. The computing graph is also saved for debugging.
5959

6060
### Prerequisites
61-
- A basic pytorch installation. The code follows **0.3**. If you are using old **0.1.12** or **0.2**, you can checkout the corresponding branch.
61+
- A basic pytorch installation. The code follows **0.4**. If you are using old **0.1.12** or **0.2** or **0.3**, you can checkout the corresponding branch.
6262
- Python packages you might not have: `cffi`, `opencv-python`, `easydict` (similar to [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn)). For `easydict` make sure you have the right version. Xinlei uses 1.6.
6363
- [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) to visualize the training and validation curve. Please build from source to use the latest tensorflow-tensorboard.
6464
- ~~Docker users: Since the recent upgrade, the docker image on docker hub (https://hub.docker.com/r/mbuckler/tf-faster-rcnn-deps/) is no longer valid. However, you can still build your own image by using the dockerfile located in the `docker` folder (cuda 8 version, as required by Tensorflow r1.0). Also make sure to follow the Tensorflow installation instructions to install and use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). Finally, after launching the container, you have to build the Cython modules within the running container.~~

lib/layer_utils/proposal_layer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from model.nms_wrapper import nms
1414

1515
import torch
16-
from torch.autograd import Variable
17-
1816

1917
def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride, anchors, num_anchors):
2018
"""A simplified version compared to fast/er RCNN
@@ -50,7 +48,7 @@ def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride,
5048
scores = scores[keep,]
5149

5250
# Only support single image as input
53-
batch_inds = Variable(proposals.data.new(proposals.size(0), 1).zero_())
51+
batch_inds = proposals.new_zeros(proposals.size(0), 1)
5452
blob = torch.cat((batch_inds, proposals), 1)
5553

5654
return blob, scores

lib/layer_utils/proposal_target_layer.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717

1818
import torch
19-
from torch.autograd import Variable
2019

2120
def proposal_target_layer(rpn_rois, rpn_scores, gt_boxes, _num_classes):
2221
"""
@@ -31,7 +30,7 @@ def proposal_target_layer(rpn_rois, rpn_scores, gt_boxes, _num_classes):
3130

3231
# Include ground-truth boxes in the set of candidate rois
3332
if cfg.TRAIN.USE_GT:
34-
zeros = rpn_rois.data.new(gt_boxes.shape[0], 1)
33+
zeros = rpn_rois.new_zeros(gt_boxes.shape[0], 1)
3534
all_rois = torch.cat(
3635
(all_rois, torch.cat((zeros, gt_boxes[:, :-1]), 1))
3736
, 0)
@@ -55,7 +54,7 @@ def proposal_target_layer(rpn_rois, rpn_scores, gt_boxes, _num_classes):
5554
bbox_inside_weights = bbox_inside_weights.view(-1, _num_classes * 4)
5655
bbox_outside_weights = (bbox_inside_weights > 0).float()
5756

58-
return rois, roi_scores, labels, Variable(bbox_targets), Variable(bbox_inside_weights), Variable(bbox_outside_weights)
57+
return rois, roi_scores, labels, bbox_targets, bbox_inside_weights, bbox_outside_weights
5958

6059

6160
def _get_bbox_regression_labels(bbox_target_data, num_classes):
@@ -72,8 +71,8 @@ def _get_bbox_regression_labels(bbox_target_data, num_classes):
7271
# Inputs are tensor
7372

7473
clss = bbox_target_data[:, 0]
75-
bbox_targets = clss.new(clss.numel(), 4 * num_classes).zero_()
76-
bbox_inside_weights = clss.new(bbox_targets.shape).zero_()
74+
bbox_targets = clss.new_zeros(clss.numel(), 4 * num_classes)
75+
bbox_inside_weights = clss.new_zeros(bbox_targets.shape)
7776
inds = (clss > 0).nonzero().view(-1)
7877
if inds.numel() > 0:
7978
clss = clss[inds].contiguous().view(-1,1)
@@ -122,17 +121,17 @@ def _sample_rois(all_rois, all_scores, gt_boxes, fg_rois_per_image, rois_per_ima
122121
# Small modification to the original version where we ensure a fixed number of regions are sampled
123122
if fg_inds.numel() > 0 and bg_inds.numel() > 0:
124123
fg_rois_per_image = min(fg_rois_per_image, fg_inds.numel())
125-
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(fg_rois_per_image), replace=False)).long().cuda()]
124+
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(fg_rois_per_image), replace=False)).long().to(gt_boxes.device)]
126125
bg_rois_per_image = rois_per_image - fg_rois_per_image
127126
to_replace = bg_inds.numel() < bg_rois_per_image
128-
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(bg_rois_per_image), replace=to_replace)).long().cuda()]
127+
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(bg_rois_per_image), replace=to_replace)).long().to(gt_boxes.device)]
129128
elif fg_inds.numel() > 0:
130129
to_replace = fg_inds.numel() < rois_per_image
131-
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().cuda()]
130+
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().to(gt_boxes.device)]
132131
fg_rois_per_image = rois_per_image
133132
elif bg_inds.numel() > 0:
134133
to_replace = bg_inds.numel() < rois_per_image
135-
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().cuda()]
134+
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().to(gt_boxes.device)]
136135
fg_rois_per_image = 0
137136
else:
138137
import pdb

lib/layer_utils/proposal_top_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, ancho
3030
if length < rpn_top_n:
3131
# Random selection, maybe unnecessary and loses good proposals
3232
# But such case rarely happens
33-
top_inds = torch.from_numpy(npr.choice(length, size=rpn_top_n, replace=True)).long().cuda()
33+
top_inds = torch.from_numpy(npr.choice(length, size=rpn_top_n, replace=True)).long().to(anchors.device)
3434
else:
3535
top_inds = scores.sort(0, descending=True)[1]
3636
top_inds = top_inds[:rpn_top_n]
@@ -50,6 +50,6 @@ def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, ancho
5050
# Output rois blob
5151
# Our RPN implementation only supports a single input image, so all
5252
# batch inds are 0
53-
batch_inds = proposals.data.new(proposals.size(0), 1).zero_()
53+
batch_inds = proposals.new_zeros(proposals.size(0), 1)
5454
blob = torch.cat([batch_inds, proposals], 1)
5555
return blob, scores

lib/layer_utils/roi_pooling/roi_pool_py.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

lib/model/train_val.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def train_model(self, max_iters):
233233
next_stepsize = stepsizes.pop()
234234

235235
self.net.train()
236-
self.net.cuda()
236+
self.net.to(self.net._device)
237237

238238
while iter < max_iters + 1:
239239
# Learning rate

lib/nets/network.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self):
4747
self._event_summaries = {}
4848
self._image_gt_summaries = {}
4949
self._variables_to_fix = {}
50+
self._device = 'cuda'
5051

5152
def _add_gt_image(self):
5253
# add back mean
@@ -125,10 +126,10 @@ def _anchor_target_layer(self, rpn_cls_score):
125126
anchor_target_layer(
126127
rpn_cls_score.data, self._gt_boxes.data.cpu().numpy(), self._im_info, self._feat_stride, self._anchors.data.cpu().numpy(), self._num_anchors)
127128

128-
rpn_labels = Variable(torch.from_numpy(rpn_labels).float().cuda()) #.set_shape([1, 1, None, None])
129-
rpn_bbox_targets = Variable(torch.from_numpy(rpn_bbox_targets).float().cuda())#.set_shape([1, None, None, self._num_anchors * 4])
130-
rpn_bbox_inside_weights = Variable(torch.from_numpy(rpn_bbox_inside_weights).float().cuda())#.set_shape([1, None, None, self._num_anchors * 4])
131-
rpn_bbox_outside_weights = Variable(torch.from_numpy(rpn_bbox_outside_weights).float().cuda())#.set_shape([1, None, None, self._num_anchors * 4])
129+
rpn_labels = torch.from_numpy(rpn_labels).float().to(self._device) #.set_shape([1, 1, None, None])
130+
rpn_bbox_targets = torch.from_numpy(rpn_bbox_targets).float().to(self._device)#.set_shape([1, None, None, self._num_anchors * 4])
131+
rpn_bbox_inside_weights = torch.from_numpy(rpn_bbox_inside_weights).float().to(self._device)#.set_shape([1, None, None, self._num_anchors * 4])
132+
rpn_bbox_outside_weights = torch.from_numpy(rpn_bbox_outside_weights).float().to(self._device)#.set_shape([1, None, None, self._num_anchors * 4])
132133

133134
rpn_labels = rpn_labels.long()
134135
self._anchor_targets['rpn_labels'] = rpn_labels
@@ -164,7 +165,7 @@ def _anchor_component(self, height, width):
164165
anchors, anchor_length = generate_anchors_pre(\
165166
height, width,
166167
self._feat_stride, self._anchor_scales, self._anchor_ratios)
167-
self._anchors = Variable(torch.from_numpy(anchors).cuda())
168+
self._anchors = torch.from_numpy(anchors).to(self._device)
168169
self._anchor_length = anchor_length
169170

170171
def _smooth_l1_loss(self, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights, sigma=1.0, dim=[1]):
@@ -186,7 +187,7 @@ def _add_losses(self, sigma_rpn=3.0):
186187
# RPN, class loss
187188
rpn_cls_score = self._predictions['rpn_cls_score_reshape'].view(-1, 2)
188189
rpn_label = self._anchor_targets['rpn_labels'].view(-1)
189-
rpn_select = Variable((rpn_label.data != -1).nonzero().view(-1))
190+
rpn_select = (rpn_label.data != -1).nonzero().view(-1)
190191
rpn_cls_score = rpn_cls_score.index_select(0, rpn_select).contiguous().view(-1, 2)
191192
rpn_label = rpn_label.index_select(0, rpn_select).contiguous().view(-1)
192193
rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)
@@ -325,7 +326,7 @@ def _run_summary_op(self, val=False):
325326
summaries.append(self._add_gt_image_summary())
326327
# Add event_summaries
327328
for key, var in self._event_summaries.items():
328-
summaries.append(tb.summary.scalar(key, var.data[0]))
329+
summaries.append(tb.summary.scalar(key, var.item()))
329330
self._event_summaries = {}
330331
if not val:
331332
# Add score summaries
@@ -375,9 +376,9 @@ def forward(self, image, im_info, gt_boxes=None, mode='TRAIN'):
375376
self._image_gt_summaries['gt_boxes'] = gt_boxes
376377
self._image_gt_summaries['im_info'] = im_info
377378

378-
self._image = Variable(torch.from_numpy(image.transpose([0,3,1,2])).cuda(), volatile=mode == 'TEST')
379+
self._image = torch.from_numpy(image.transpose([0,3,1,2])).to(self._device)
379380
self._im_info = im_info # No need to change; actually it can be an list
380-
self._gt_boxes = Variable(torch.from_numpy(gt_boxes).cuda()) if gt_boxes is not None else None
381+
self._gt_boxes = torch.from_numpy(gt_boxes).to(self._device) if gt_boxes is not None else None
381382

382383
self._mode = mode
383384

@@ -386,7 +387,7 @@ def forward(self, image, im_info, gt_boxes=None, mode='TRAIN'):
386387
if mode == 'TEST':
387388
stds = bbox_pred.data.new(cfg.TRAIN.BBOX_NORMALIZE_STDS).repeat(self._num_classes).unsqueeze(0).expand_as(bbox_pred)
388389
means = bbox_pred.data.new(cfg.TRAIN.BBOX_NORMALIZE_MEANS).repeat(self._num_classes).unsqueeze(0).expand_as(bbox_pred)
389-
self._predictions["bbox_pred"] = bbox_pred.mul(Variable(stds)).add(Variable(means))
390+
self._predictions["bbox_pred"] = bbox_pred.mul(stds).add(means)
390391
else:
391392
self._add_losses() # compute losses
392393

@@ -411,13 +412,14 @@ def normal_init(m, mean, stddev, truncated=False):
411412
# Extract the head feature maps, for example for vgg16 it is conv5_3
412413
# only useful during testing mode
413414
def extract_head(self, image):
414-
feat = self._layers["head"](Variable(torch.from_numpy(image.transpose([0,3,1,2])).cuda(), volatile=True))
415+
feat = self._layers["head"](torch.from_numpy(image.transpose([0,3,1,2])).to(self._device))
415416
return feat
416417

417418
# only useful during testing mode
418419
def test_image(self, image, im_info):
419420
self.eval()
420-
self.forward(image, im_info, None, mode='TEST')
421+
with torch.no_grad():
422+
self.forward(image, im_info, None, mode='TEST')
421423
cls_score, cls_prob, bbox_pred, rois = self._predictions["cls_score"].data.cpu().numpy(), \
422424
self._predictions['cls_prob'].data.cpu().numpy(), \
423425
self._predictions['bbox_pred'].data.cpu().numpy(), \
@@ -440,11 +442,11 @@ def get_summary(self, blobs):
440442

441443
def train_step(self, blobs, train_op):
442444
self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])
443-
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].data[0], \
444-
self._losses['rpn_loss_box'].data[0], \
445-
self._losses['cross_entropy'].data[0], \
446-
self._losses['loss_box'].data[0], \
447-
self._losses['total_loss'].data[0]
445+
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].item(), \
446+
self._losses['rpn_loss_box'].item(), \
447+
self._losses['cross_entropy'].item(), \
448+
self._losses['loss_box'].item(), \
449+
self._losses['total_loss'].item()
448450
#utils.timer.timer.tic('backward')
449451
train_op.zero_grad()
450452
self._losses['total_loss'].backward()
@@ -457,11 +459,11 @@ def train_step(self, blobs, train_op):
457459

458460
def train_step_with_summary(self, blobs, train_op):
459461
self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])
460-
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].data[0], \
461-
self._losses['rpn_loss_box'].data[0], \
462-
self._losses['cross_entropy'].data[0], \
463-
self._losses['loss_box'].data[0], \
464-
self._losses['total_loss'].data[0]
462+
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].item(), \
463+
self._losses['rpn_loss_box'].item(), \
464+
self._losses['cross_entropy'].item(), \
465+
self._losses['loss_box'].item(), \
466+
self._losses['total_loss'].item()
465467
train_op.zero_grad()
466468
self._losses['total_loss'].backward()
467469
train_op.step()

lib/utils/timer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ def __init__(self):
2020
def tic(self, name='default'):
2121
# using time.time instead of time.clock because time.clock
2222
# does not normalize for multithreading
23-
torch.cuda.synchronize()
23+
if torch.cuda.is_available():
24+
torch.cuda.synchronize()
2425
self._start_time[name] = time.time()
2526

2627
def toc(self, name='default', average=True):
27-
torch.cuda.synchronize()
28+
if torch.cuda.is_available():
29+
torch.cuda.synchronize()
2830
self._diff[name] = time.time() - self._start_time[name]
2931
self._total_time[name] = self._total_time.get(name, 0.) + self._diff[name]
3032
self._calls[name] = self._calls.get(name, 0 ) + 1

tools/demo.ipynb

Lines changed: 55 additions & 48 deletions
Large diffs are not rendered by default.

tools/demo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,12 @@ def parse_args():
136136
net.create_architecture(21,
137137
tag='default', anchor_scales=[8, 16, 32])
138138

139-
net.load_state_dict(torch.load(saved_model))
139+
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))
140140

141141
net.eval()
142-
net.cuda()
142+
if not torch.cuda.is_available():
143+
net._device = 'cpu'
144+
net.to(net._device)
143145

144146
print('Loaded network {:s}'.format(saved_model))
145147

tools/test_net.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,13 @@ def parse_args():
104104
anchor_ratios=cfg.ANCHOR_RATIOS)
105105

106106
net.eval()
107-
net.cuda()
107+
if not torch.cuda.is_available():
108+
net._device = 'cpu'
109+
net.to(net._device)
108110

109111
if args.model:
110112
print(('Loading model check point from {:s}').format(args.model))
111-
net.load_state_dict(torch.load(args.model))
113+
net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
112114
print('Loaded.')
113115
else:
114116
print(('Loading initial weights from {:s}').format(args.weight))

0 commit comments

Comments (0)