Skip to content

Commit fa88df8

Browse files
committed
Use crop_and_resize from longcw/RoIAlign.pytorch. This speeds up training from 0.315s to 0.282s (on my machine).
1 parent d5012b0 commit fa88df8

15 files changed

+781
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ Additional features not mentioned in the [report](https://arxiv.org/pdf/1702.021
6969
git clone https://github.com/ruotianluo/pytorch-faster-rcnn.git
7070
```
7171

72-
2. Compile modules:
72+
2. Compile modules: nms, roi_pooling (from [longcw/yolo2-pytorch](https://github.com/longcw/yolo2-pytorch)), and roi_align (from [longcw/RoIAlign.pytorch](https://github.com/longcw/RoIAlign.pytorch.git)):
7373
```
7474
cd pytorch-faster-rcnn/lib
7575
bash make.sh

lib/layer_utils/roi_align/__init__.py

Whitespace-only changes.

lib/layer_utils/roi_align/_ext/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
from torch.utils.ffi import _wrap_function
3+
from ._crop_and_resize import lib as _lib, ffi as _ffi
4+
5+
__all__ = []
6+
def _import_symbols(locals):
    """Copy every attribute of the compiled ``_lib`` into a namespace.

    Callable attributes are wrapped with ``_wrap_function`` (bound to
    ``_ffi``) before being stored; any other attribute is stored unchanged.
    Every attribute name is also appended to the module-level ``__all__``.

    :param locals: mapping that receives the symbols, keyed by attribute name
    """
    for name in dir(_lib):
        attr = getattr(_lib, name)
        # Only callables go through the FFI wrapper; the rest are exported raw.
        locals[name] = _wrap_function(attr, _ffi) if callable(attr) else attr
        __all__.append(name)
14+
15+
_import_symbols(locals())

lib/layer_utils/roi_align/build.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import torch
3+
from torch.utils.ffi import create_extension
4+
5+
6+
# Build configuration for the ``_ext.crop_and_resize`` FFI extension.
# The CPU source/header pair is always included; the CUDA pieces are added
# only when torch reports that CUDA is available.
sources = ['src/crop_and_resize.c']
headers = ['src/crop_and_resize.h']
defines = []
extra_objects = []
extra_compile_args = ['-std=c99']

with_cuda = torch.cuda.is_available()
if with_cuda:
    print('Including CUDA code.')
    sources.append('src/crop_and_resize_gpu.c')
    headers.append('src/crop_and_resize_gpu.h')
    defines.append(('WITH_CUDA', None))
    # Linked as an already-compiled object file.
    # NOTE(review): presumably produced from the .cu kernel by a separate
    # build step -- confirm against make.sh.
    extra_objects.append('src/cuda/crop_and_resize_kernel.cu.o')

# Despite its name, ``this_file`` holds the directory containing this script.
this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)


def _under_this_dir(relative_paths):
    """Anchor each relative path at the directory containing this script."""
    return [os.path.join(this_file, path) for path in relative_paths]


sources = _under_this_dir(sources)
headers = _under_this_dir(headers)
extra_objects = _under_this_dir(extra_objects)

ffi = create_extension(
    '_ext.crop_and_resize',
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects,
    extra_compile_args=extra_compile_args
)

# Only compile when executed as a script; importing this module just
# prepares the ``ffi`` object.
if __name__ == '__main__':
    ffi.build()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
from torch.autograd import Function
6+
7+
from ._ext import crop_and_resize as _backend
8+
9+
10+
class CropAndResizeFunction(Function):
    """Autograd function that forwards crop-and-resize to the compiled backend.

    The crop size and extrapolation value are stored on the instance at
    construction time; ``forward`` and ``backward`` pick the GPU or CPU
    backend routine depending on where the incoming tensor lives.
    """

    def __init__(self, crop_height, crop_width, extrapolation_value=0):
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.extrapolation_value = extrapolation_value

    def forward(self, image, boxes, box_ind):
        """Run the backend forward routine and return the ``crops`` tensor.

        :param image: input tensor; its device selects the backend routine
        :param boxes: box tensor passed through to the backend
        :param box_ind: box index tensor passed through to the backend
        :return: the output tensor written by the backend
        """
        # Output buffer handed to the backend.
        # NOTE(review): it is allocated with the shape of ``image``;
        # presumably the backend resizes it to the crop shape -- confirm.
        crops = torch.zeros_like(image)

        run_forward = (_backend.crop_and_resize_gpu_forward if image.is_cuda
                       else _backend.crop_and_resize_forward)
        run_forward(
            image, boxes, box_ind,
            self.extrapolation_value, self.crop_height, self.crop_width, crops)

        # Remember what backward needs: the input size and the box tensors.
        self.im_size = image.size()
        self.save_for_backward(boxes, box_ind)

        return crops

    def backward(self, grad_outputs):
        """Run the backend backward routine for the image input.

        :param grad_outputs: gradient with respect to the crops
        :return: ``(grad_image, None, None)``; no gradient is produced for
            ``boxes`` or ``box_ind``
        """
        boxes, box_ind = self.saved_tensors

        grad_outputs = grad_outputs.contiguous()
        # Gradient buffer with the same size as the forward input image.
        grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size)

        run_backward = (_backend.crop_and_resize_gpu_backward if grad_outputs.is_cuda
                        else _backend.crop_and_resize_backward)
        run_backward(grad_outputs, boxes, box_ind, grad_image)

        return grad_image, None, None
51+
52+
53+
class CropAndResize(nn.Module):
    """
    ``nn.Module`` wrapper around ``CropAndResizeFunction``.

    Crop and resize ported from tensorflow
    See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize
    """

    def __init__(self, crop_height, crop_width, extrapolation_value=0):
        super(CropAndResize, self).__init__()

        self.crop_height = crop_height
        self.crop_width = crop_width
        self.extrapolation_value = extrapolation_value

    def forward(self, image, boxes, box_ind):
        """Build a ``CropAndResizeFunction`` from the stored settings and apply it."""
        crop_fn = CropAndResizeFunction(
            self.crop_height, self.crop_width, self.extrapolation_value)
        return crop_fn(image, boxes, box_ind)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
2+
from torch import nn
3+
4+
from .crop_and_resize import CropAndResizeFunction, CropAndResize
5+
6+
7+
class RoIAlign(nn.Module):
    """RoIAlign layer implemented on top of ``CropAndResizeFunction``."""

    def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True):
        super(RoIAlign, self).__init__()

        self.crop_height = crop_height
        self.crop_width = crop_width
        self.extrapolation_value = extrapolation_value
        self.transform_fpcoor = transform_fpcoor

    def forward(self, featuremap, boxes, box_ind):
        """
        RoIAlign based on crop_and_resize.

        Boxes are rescaled by ``(width - 1)`` / ``(height - 1)`` of the
        feature map and reordered to ``(y1, x1, y2, x2)`` before being passed
        to ``CropAndResizeFunction``. When ``transform_fpcoor`` is set, the
        box is first adjusted by half a sampling step minus 0.5.

        See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301
        :param featuremap: NxCxHxW
        :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization**
        :param box_ind: M
        :return: MxCxoHxoW
        """
        x1, y1, x2, y2 = torch.split(boxes, 1, dim=1)
        image_height, image_width = featuremap.size()[2:4]

        # Denominators used to normalize coordinates.
        max_x = float(image_width - 1)
        max_y = float(image_height - 1)

        if self.transform_fpcoor:
            # Size of one sampling step along each axis of the box.
            step_w = (x2 - x1) / float(self.crop_width)
            step_h = (y2 - y1) / float(self.crop_height)

            left = (x1 + step_w / 2 - 0.5) / max_x
            top = (y1 + step_h / 2 - 0.5) / max_y
            norm_w = step_w * float(self.crop_width - 1) / max_x
            norm_h = step_h * float(self.crop_height - 1) / max_y

            corners = (top, left, top + norm_h, left + norm_w)
        else:
            left = x1 / max_x
            right = x2 / max_x
            top = y1 / max_y
            bottom = y2 / max_y
            corners = (top, left, bottom, right)

        # The box tensors are detached before being handed to the function.
        boxes = torch.cat(corners, 1).detach().contiguous()
        box_ind = box_ind.detach()

        crop_fn = CropAndResizeFunction(
            self.crop_height, self.crop_width, self.extrapolation_value)
        return crop_fn(featuremap, boxes, box_ind)

0 commit comments

Comments
 (0)