Skip to content

Commit e6bdac3

Browse files
adityaarun1ruotianluo
authored and committed
add PyTorch 1.0 support
1 parent f916db4 commit e6bdac3

28 files changed

+1517
-77
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ lib/build
66
lib/pycocotools
77
lib/pycocotools/_mask.c
88
lib/pycocotools/_mask.so
9+
lib/faster_rcnn.egg-info
910
.idea

experiments/cfgs/mobile.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ TRAIN:
1313
SNAPSHOT_PREFIX: mobile_faster_rcnn
1414
TEST:
1515
HAS_RPN: True
16-
POOLING_MODE: crop
16+
POOLING_MODE: align

experiments/cfgs/res101-lg.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ TEST:
1818
SCALES: [800]
1919
MAX_SIZE: 1333
2020
RPN_POST_NMS_TOP_N: 1000
21-
POOLING_MODE: crop
21+
POOLING_MODE: align
2222
ANCHOR_SCALES: [2,4,8,16,32]

experiments/cfgs/res101.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ TRAIN:
1313
SNAPSHOT_PREFIX: res101_faster_rcnn
1414
TEST:
1515
HAS_RPN: True
16-
POOLING_MODE: crop
16+
POOLING_MODE: align

experiments/cfgs/res50.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ TRAIN:
1313
SNAPSHOT_PREFIX: res50_faster_rcnn
1414
TEST:
1515
HAS_RPN: True
16-
POOLING_MODE: crop
16+
POOLING_MODE: align

experiments/cfgs/vgg16.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ TRAIN:
1212
SNAPSHOT_PREFIX: vgg16_faster_rcnn
1313
TEST:
1414
HAS_RPN: True
15-
POOLING_MODE: crop
15+
POOLING_MODE: align

lib/layer_utils/csrc/ROIAlign.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
#pragma once
3+
4+
#include "cpu/vision.h"
5+
6+
#ifdef WITH_CUDA
7+
#include "cuda/vision.h"
8+
#endif
9+
10+
// Interface for Python
11+
// Device dispatcher for the ROIAlign forward pass.
//
// When `input` is a CUDA tensor the call is forwarded to
// ROIAlign_forward_cuda if the extension was built with WITH_CUDA; without
// WITH_CUDA it raises "Not compiled with GPU support". Any other input is
// forwarded to ROIAlign_forward_cpu. All arguments are passed through
// unchanged and the callee's result is returned as-is.
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition would produce multiple-definition link errors as soon as the
// header is included from more than one translation unit.
inline at::Tensor ROIAlign_forward(const at::Tensor& input,
                                   const at::Tensor& rois,
                                   const float spatial_scale,
                                   const int pooled_height,
                                   const int pooled_width,
                                   const int sampling_ratio) {
  if (input.type().is_cuda()) {
#ifdef WITH_CUDA
    return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
26+
27+
// Device dispatcher for the ROIAlign backward pass.
//
// When `grad` is a CUDA tensor the call is forwarded to
// ROIAlign_backward_cuda if the extension was built with WITH_CUDA; without
// WITH_CUDA it raises "Not compiled with GPU support". There is no CPU
// backward implementation: any non-CUDA `grad` raises
// "Not implemented on the CPU".
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition would produce multiple-definition link errors as soon as the
// header is included from more than one translation unit.
inline at::Tensor ROIAlign_backward(const at::Tensor& grad,
                                    const at::Tensor& rois,
                                    const float spatial_scale,
                                    const int pooled_height,
                                    const int pooled_width,
                                    const int batch_size,
                                    const int channels,
                                    const int height,
                                    const int width,
                                    const int sampling_ratio) {
  if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
    return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}
46+

lib/layer_utils/csrc/ROIPool.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
#pragma once
3+
4+
#include "cpu/vision.h"
5+
6+
#ifdef WITH_CUDA
7+
#include "cuda/vision.h"
8+
#endif
9+
10+
11+
// Device dispatcher for the ROIPool forward pass.
//
// When `input` is a CUDA tensor the call is forwarded to
// ROIPool_forward_cuda if the extension was built with WITH_CUDA; without
// WITH_CUDA it raises "Not compiled with GPU support". There is no CPU
// implementation: any non-CUDA `input` raises "Not implemented on the CPU".
//
// Returns the tensor pair produced by ROIPool_forward_cuda unchanged
// (presumably the pooled output and the argmax indices consumed by
// ROIPool_backward -- confirm against the CUDA implementation).
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition would produce multiple-definition link errors as soon as the
// header is included from more than one translation unit.
inline std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
                                                          const at::Tensor& rois,
                                                          const float spatial_scale,
                                                          const int pooled_height,
                                                          const int pooled_width) {
  if (input.type().is_cuda()) {
#ifdef WITH_CUDA
    return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}
25+
26+
// Device dispatcher for the ROIPool backward pass.
//
// When `grad` is a CUDA tensor the call is forwarded to
// ROIPool_backward_cuda if the extension was built with WITH_CUDA; without
// WITH_CUDA it raises "Not compiled with GPU support". There is no CPU
// implementation: any non-CUDA `grad` raises "Not implemented on the CPU".
// All arguments are passed through unchanged.
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition would produce multiple-definition link errors as soon as the
// header is included from more than one translation unit.
inline at::Tensor ROIPool_backward(const at::Tensor& grad,
                                   const at::Tensor& input,
                                   const at::Tensor& rois,
                                   const at::Tensor& argmax,
                                   const float spatial_scale,
                                   const int pooled_height,
                                   const int pooled_width,
                                   const int batch_size,
                                   const int channels,
                                   const int height,
                                   const int width) {
  if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
    return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}
46+
47+
48+
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
#include "cpu/vision.h"
3+
4+
// implementation taken from Caffe2
5+
// Precomputed bilinear-interpolation data for a single sampling point.
//
// pos1..pos4 are flat offsets (row * width + column) into one feature plane
// for the four neighbouring cells, in the order
// (y_low, x_low), (y_low, x_high), (y_high, x_low), (y_high, x_high);
// w1..w4 are the matching interpolation weights.
//
// Every member is zero-initialised so that a default-constructed PreCalc is a
// well-defined "empty" sample (offset 0, weight 0) instead of holding
// indeterminate values.
template <typename T>
struct PreCalc {
  int pos1{0};
  int pos2{0};
  int pos3{0};
  int pos4{0};
  T w1{};
  T w2{};
  T w3{};
  T w4{};
};
16+
17+
template <typename T>
18+
void pre_calc_for_bilinear_interpolate(
19+
const int height,
20+
const int width,
21+
const int pooled_height,
22+
const int pooled_width,
23+
const int iy_upper,
24+
const int ix_upper,
25+
T roi_start_h,
26+
T roi_start_w,
27+
T bin_size_h,
28+
T bin_size_w,
29+
int roi_bin_grid_h,
30+
int roi_bin_grid_w,
31+
std::vector<PreCalc<T>>& pre_calc) {
32+
int pre_calc_index = 0;
33+
for (int ph = 0; ph < pooled_height; ph++) {
34+
for (int pw = 0; pw < pooled_width; pw++) {
35+
for (int iy = 0; iy < iy_upper; iy++) {
36+
const T yy = roi_start_h + ph * bin_size_h +
37+
static_cast<T>(iy + .5f) * bin_size_h /
38+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
39+
for (int ix = 0; ix < ix_upper; ix++) {
40+
const T xx = roi_start_w + pw * bin_size_w +
41+
static_cast<T>(ix + .5f) * bin_size_w /
42+
static_cast<T>(roi_bin_grid_w);
43+
44+
T x = xx;
45+
T y = yy;
46+
// deal with: inverse elements are out of feature map boundary
47+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
48+
// empty
49+
PreCalc<T> pc;
50+
pc.pos1 = 0;
51+
pc.pos2 = 0;
52+
pc.pos3 = 0;
53+
pc.pos4 = 0;
54+
pc.w1 = 0;
55+
pc.w2 = 0;
56+
pc.w3 = 0;
57+
pc.w4 = 0;
58+
pre_calc[pre_calc_index] = pc;
59+
pre_calc_index += 1;
60+
continue;
61+
}
62+
63+
if (y <= 0) {
64+
y = 0;
65+
}
66+
if (x <= 0) {
67+
x = 0;
68+
}
69+
70+
int y_low = (int)y;
71+
int x_low = (int)x;
72+
int y_high;
73+
int x_high;
74+
75+
if (y_low >= height - 1) {
76+
y_high = y_low = height - 1;
77+
y = (T)y_low;
78+
} else {
79+
y_high = y_low + 1;
80+
}
81+
82+
if (x_low >= width - 1) {
83+
x_high = x_low = width - 1;
84+
x = (T)x_low;
85+
} else {
86+
x_high = x_low + 1;
87+
}
88+
89+
T ly = y - y_low;
90+
T lx = x - x_low;
91+
T hy = 1. - ly, hx = 1. - lx;
92+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
93+
94+
// save weights and indeces
95+
PreCalc<T> pc;
96+
pc.pos1 = y_low * width + x_low;
97+
pc.pos2 = y_low * width + x_high;
98+
pc.pos3 = y_high * width + x_low;
99+
pc.pos4 = y_high * width + x_high;
100+
pc.w1 = w1;
101+
pc.w2 = w2;
102+
pc.w3 = w3;
103+
pc.w4 = w4;
104+
pre_calc[pre_calc_index] = pc;
105+
106+
pre_calc_index += 1;
107+
}
108+
}
109+
}
110+
}
111+
}
112+
113+
template <typename T>
114+
void ROIAlignForward_cpu_kernel(
115+
const int nthreads,
116+
const T* bottom_data,
117+
const T& spatial_scale,
118+
const int channels,
119+
const int height,
120+
const int width,
121+
const int pooled_height,
122+
const int pooled_width,
123+
const int sampling_ratio,
124+
const T* bottom_rois,
125+
//int roi_cols,
126+
T* top_data) {
127+
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
128+
int roi_cols = 5;
129+
130+
int n_rois = nthreads / channels / pooled_width / pooled_height;
131+
// (n, c, ph, pw) is an element in the pooled output
132+
// can be parallelized using omp
133+
// #pragma omp parallel for num_threads(32)
134+
for (int n = 0; n < n_rois; n++) {
135+
int index_n = n * channels * pooled_width * pooled_height;
136+
137+
// roi could have 4 or 5 columns
138+
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
139+
int roi_batch_ind = 0;
140+
if (roi_cols == 5) {
141+
roi_batch_ind = offset_bottom_rois[0];
142+
offset_bottom_rois++;
143+
}
144+
145+
// Do not using rounding; this implementation detail is critical
146+
T roi_start_w = offset_bottom_rois[0] * spatial_scale;
147+
T roi_start_h = offset_bottom_rois[1] * spatial_scale;
148+
T roi_end_w = offset_bottom_rois[2] * spatial_scale;
149+
T roi_end_h = offset_bottom_rois[3] * spatial_scale;
150+
// T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
151+
// T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
152+
// T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
153+
// T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
154+
155+
// Force malformed ROIs to be 1x1
156+
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
157+
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
158+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
159+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
160+
161+
// We use roi_bin_grid to sample the grid and mimic integral
162+
int roi_bin_grid_h = (sampling_ratio > 0)
163+
? sampling_ratio
164+
: ceil(roi_height / pooled_height); // e.g., = 2
165+
int roi_bin_grid_w =
166+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
167+
168+
// We do average (integral) pooling inside a bin
169+
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
170+
171+
// we want to precalculate indeces and weights shared by all chanels,
172+
// this is the key point of optimiation
173+
std::vector<PreCalc<T>> pre_calc(
174+
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
175+
pre_calc_for_bilinear_interpolate(
176+
height,
177+
width,
178+
pooled_height,
179+
pooled_width,
180+
roi_bin_grid_h,
181+
roi_bin_grid_w,
182+
roi_start_h,
183+
roi_start_w,
184+
bin_size_h,
185+
bin_size_w,
186+
roi_bin_grid_h,
187+
roi_bin_grid_w,
188+
pre_calc);
189+
190+
for (int c = 0; c < channels; c++) {
191+
int index_n_c = index_n + c * pooled_width * pooled_height;
192+
const T* offset_bottom_data =
193+
bottom_data + (roi_batch_ind * channels + c) * height * width;
194+
int pre_calc_index = 0;
195+
196+
for (int ph = 0; ph < pooled_height; ph++) {
197+
for (int pw = 0; pw < pooled_width; pw++) {
198+
int index = index_n_c + ph * pooled_width + pw;
199+
200+
T output_val = 0.;
201+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
202+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
203+
PreCalc<T> pc = pre_calc[pre_calc_index];
204+
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
205+
pc.w2 * offset_bottom_data[pc.pos2] +
206+
pc.w3 * offset_bottom_data[pc.pos3] +
207+
pc.w4 * offset_bottom_data[pc.pos4];
208+
209+
pre_calc_index += 1;
210+
}
211+
}
212+
output_val /= count;
213+
214+
top_data[index] = output_val;
215+
} // for pw
216+
} // for ph
217+
} // for c
218+
} // for n
219+
}
220+
221+
// CPU entry point for the ROIAlign forward pass.
//
// Allocates an output tensor of shape
// (rois.size(0), input.size(1), pooled_height, pooled_width) with the same
// options as `input` and fills it via ROIAlignForward_cpu_kernel, dispatched
// on the floating-point type of `input`.
//
// Raises if `input` or `rois` is a CUDA tensor, or if `rois` is not a 2-D
// tensor with 5 columns (the kernel reads 5 values per ROI). An empty output
// is returned immediately without touching the data.
//
// The kernel indexes raw pointers assuming a dense layout, so contiguous
// views/copies of `input` and `rois` are taken first; otherwise a
// non-contiguous tensor (e.g. the result of a permute or slice) would be read
// with the wrong strides.
at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
                                const at::Tensor& rois,
                                const float spatial_scale,
                                const int pooled_height,
                                const int pooled_width,
                                const int sampling_ratio) {
  AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
  AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
  auto output_size = num_rois * pooled_height * pooled_width * channels;

  if (output.numel() == 0) {
    return output;
  }

  // The kernel strides through `rois` five values at a time.
  AT_ASSERTM(rois.dim() == 2 && rois.size(1) == 5, "rois must have shape (num_rois, 5)");

  // No-ops when the tensors are already contiguous.
  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
    ROIAlignForward_cpu_kernel<scalar_t>(
        output_size,
        input_.data<scalar_t>(),
        spatial_scale,
        channels,
        height,
        width,
        pooled_height,
        pooled_width,
        sampling_ratio,
        rois_.data<scalar_t>(),
        output.data<scalar_t>());
  });
  return output;
}

0 commit comments

Comments
 (0)