Commit ec1d91b

Move {SpatialConvolutionMM, Spatial(Adaptive,Average,Max)Pooling} to lib/THCUNN
1 parent 641522a commit ec1d91b

File tree

4 files changed: +1287 -0 lines changed


SpatialAdaptiveMaxPooling.cu

Lines changed: 395 additions & 0 deletions
@@ -0,0 +1,395 @@
#include "utils.h"
#include <cfloat> // for FLT_MAX, in case utils.h does not pull it in

#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit

/*
 * Description:
 *    this function adaptively maxpools an input 4D tensor along dimensions 2 and 3
 *    4D input, 4D output, 4D argmax x and y
 */
10+
__global__ void adaptivemaxpool(float *input, float *output, float *indices_x, float *indices_y,
11+
int input_n, int input_h, int input_w,
12+
int output_h, int output_w,
13+
int strideh, int stridew,
14+
int strided)
15+
{
  // iterators
  int xx, yy;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;
  //int k = blockIdx.x % input_n;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  const int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  const int yy_step = blockDim.y*gridDim.y;

  // select input/output plane
  output = output + o*output_w*output_h;
  input = input + i*strided;
  indices_x = indices_x + o*output_w*output_h;
  indices_y = indices_y + o*output_w*output_h;

  // For all output pixels...
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {

    int y_start = (int)floor(float(yy) / output_h * input_h);
    int y_end   = (int)ceil(float(yy+1) / output_h * input_h);
    int kH = y_end-y_start;

    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
      int x_start = (int)floor(float(xx) / output_w * input_w);
      int x_end   = (int)ceil(float(xx + 1) / output_w * input_w);
      int kW = x_end-x_start;

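      // Window boundaries are floor(i*in/out) .. ceil((i+1)*in/out): the windows
      // cover the whole input but can overlap by one row/column when the input
      // size is not a multiple of the output size. E.g. input_w=10, output_w=3
      // gives x-windows [0,4), [3,7), [6,10), so columns 3 and 6 are shared.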
      // Compute the max of the input window...
      float *ptr_input = input + y_start*strideh + x_start*stridew;
      float *ptr_output = output + yy*output_w + xx;
      float *ptr_ind_x = indices_x + yy*output_w + xx;
      float *ptr_ind_y = indices_y + yy*output_w + xx;
      int argmax_x = -1;
      int argmax_y = -1;
      float max = -FLT_MAX;
      int kx, ky;
      for(ky = 0; ky < kH; ky++) {
        for(kx = 0; kx < kW; kx++) {
          float val = ptr_input[kx*stridew];
          if (val > max) {
            max = val;
            argmax_x = kx;
            argmax_y = ky;
          }
        }
        ptr_input += strideh; // next input line
      }
      // Update output and argmax (+1 converts to Lua's 1-based indexing)
      *ptr_output = max;
      *ptr_ind_x = argmax_x + 1;
      *ptr_ind_y = argmax_y + 1;
    }
  }
}

/*
 * Description:
 *    this function computes the gradInput from gradOutput and the argmax indices
 */
__global__ void adaptivemaxgradinput(float *gradInput, float *gradOutput, float *indices_x, float *indices_y,
                                     int input_n, int input_h, int input_w,
                                     int output_h, int output_w)
{
  // iterators
  int xx, yy;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;
  //int k = blockIdx.x % input_n;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  // select input/output plane
  gradOutput = gradOutput + o*output_w*output_h;
  gradInput = gradInput + i*input_w*input_h;
  indices_x = indices_x + o*output_w*output_h;
  indices_y = indices_y + o*output_w*output_h;

  // compute gradInput
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {

    int y_start = (int)floor(float(yy) / output_h * input_h);

    for(xx = xx_start; xx < xx_end; xx+=xx_step) {

      int x_start = (int)floor(float(xx) / output_w * input_w);

      float *ptr_gradInput = gradInput + y_start*input_w + x_start;
      float *ptr_gradOutput = gradOutput + yy*output_w + xx;
      float *ptr_ind_x = indices_x + yy*output_w + xx;
      float *ptr_ind_y = indices_y + yy*output_w + xx;
      float z = *ptr_gradOutput;

      int argmax_x = (*ptr_ind_x)-1;
      int argmax_y = (*ptr_ind_y)-1;

      ptr_gradInput[argmax_x + argmax_y*input_w] += z;
    }
  }
}

/*
 * Description:
 *    this function computes the gradInput from gradOutput and the argmax indices
 *    using atomic add, since adaptive pooling windows can overlap when the input
 *    size is not a multiple of the output size
 */
__global__ void atomicadaptivemaxgradinput(
  float *gradInput, float *gradOutput, float *indices_x, float *indices_y,
  int input_n, int input_h, int input_w, int output_h, int output_w
)
{
  // iterators
  int xx, yy;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  // select input/output plane
  gradOutput = gradOutput + o*output_w*output_h;
  gradInput = gradInput + i*input_w*input_h;
  indices_x = indices_x + o*output_w*output_h;
  indices_y = indices_y + o*output_w*output_h;

  // compute gradInput
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {

    int y_start = (int)floor(float(yy) / output_h * input_h);

    for(xx = xx_start; xx < xx_end; xx+=xx_step) {

      int x_start = (int)floor(float(xx) / output_w * input_w);

      float *ptr_gradInput = gradInput + y_start*input_w + x_start;
      float *ptr_gradOutput = gradOutput + yy*output_w + xx;
      float *ptr_ind_x = indices_x + yy*output_w + xx;
      float *ptr_ind_y = indices_y + yy*output_w + xx;
      float z = *ptr_gradOutput;

      int argmax_x = (*ptr_ind_x)-1;
      int argmax_y = (*ptr_ind_y)-1;

      // atomic add since different threads could update same variable
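      // (float atomicAdd on global memory requires compute capability 2.0+)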
      atomicAdd(&(ptr_gradInput[argmax_x + argmax_y*input_w]), z);
    }
  }
}

static int cunn_SpatialAdaptiveMaxPooling_updateOutput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");

  long nOutputCols = luaT_getfieldcheckint(L, 1, "W");
  long nOutputRows = luaT_getfieldcheckint(L, 1, "H");

  THCudaTensor *output = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THCudaTensor *indices = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 3, input, output, indices));

  float *indices_data;
  float *output_data;
  float *input_data;

  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");

  if (input->nDimension == 3) {
    long nInputCols = input->size[2];
    long nInputRows = input->size[1];
    long nInputPlane = input->size[0];

    long istride_d = input->stride[0];
    long istride_h = input->stride[1];
    long istride_w = input->stride[2];

    input_data = THCudaTensor_data(state, input);

    THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
    THCudaTensor_resize4d(state, indices, 2, nInputPlane, nOutputRows, nOutputCols);

    indices_data = THCudaTensor_data(state, indices);
    output_data = THCudaTensor_data(state, output);

    // cuda blocks & threads:
    int yblocks = (int)(16L / nInputPlane);
    yblocks = yblocks < 1 ? 1 : yblocks;
    dim3 blocks(nInputPlane, yblocks);
    dim3 threads(32, 8);

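    // indices holds two [nInputPlane x H x W] planes: plane 0 stores the y (row)
    // argmax and plane 1 the x (col) argmax, hence the indices_data offset below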
    // run maxpool kernel
    adaptivemaxpool <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data,
                                        indices_data+nInputPlane*nOutputCols*nOutputRows, indices_data,
                                        nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
                                        istride_h, istride_w, istride_d);

  } else {
    long nInputCols = input->size[3];
    long nInputRows = input->size[2];
    long nInputPlane = input->size[1];
    long nbatch = input->size[0];

    // make the input contiguous before reading its strides, since
    // THCudaTensor_newContiguous may return a copy with different strides
    input = THCudaTensor_newContiguous(state, input);

    long istride_d = input->stride[1];
    long istride_h = input->stride[2];
    long istride_w = input->stride[3];

    input_data = THCudaTensor_data(state, input);

    THCudaTensor_resize4d(state, output, nbatch, nInputPlane, nOutputRows, nOutputCols);
    THCudaTensor_resize5d(state, indices, 2, nbatch, nInputPlane, nOutputRows, nOutputCols);

    indices_data = THCudaTensor_data(state, indices);
    output_data = THCudaTensor_data(state, output);

    // cuda blocks & threads:
    int yblocks = (int)(16L / nInputPlane);
    yblocks = yblocks < 1 ? 1 : yblocks;
    dim3 blocks(nInputPlane*nbatch, yblocks);
    dim3 threads(32, 8);

    // run maxpool kernel
    adaptivemaxpool <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data,
                                        indices_data+nbatch*nInputPlane*nOutputCols*nOutputRows, indices_data,
                                        nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
                                        istride_h, istride_w, istride_d);
    // clean
    THCudaTensor_free(state, input);
  }

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in SpatialAdaptiveMaxPooling.updateOutput: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
  return 1;
}

static int cunn_SpatialAdaptiveMaxPooling_updateGradInput(lua_State *L)
{
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");

  bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests
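  // (adaptive pooling windows can overlap, so two output cells may have their
  // argmax at the same input cell; plain += would then race across threads)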

  THCudaTensor *gradInput = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  THCudaTensor *indices = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 4, input, indices, gradOutput, gradInput));

  float *indices_data;
  float *gradInput_data;
  float *gradOutput_data;

  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  if (input->nDimension == 3) {
    long nInputCols = input->size[2];
    long nInputRows = input->size[1];
    long nInputPlane = input->size[0];
    long nOutputCols = gradOutput->size[2];
    long nOutputRows = gradOutput->size[1];

    //bool atomic = (nInputCols%nOutputCols != 0) || (nInputRows%nOutputRows != 0);

    THCudaTensor_resizeAs(state, gradInput, input);
    THCudaTensor_zero(state, gradInput);

    indices_data = THCudaTensor_data(state, indices);
    gradOutput_data = THCudaTensor_data(state, gradOutput);
    gradInput_data = THCudaTensor_data(state, gradInput);

    // cuda blocks & threads:
    int yblocks = (int)(16L / nInputPlane);
    yblocks = yblocks < 1 ? 1 : yblocks;
    dim3 blocks(nInputPlane, yblocks);
    dim3 threads(32, 8);

    if(atomic)
    {
      // run updateGradInput kernel, accumulate gradients atomically
      atomicadaptivemaxgradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
                                          indices_data+nInputPlane*nOutputCols*nOutputRows, indices_data,
                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
    }
    else
    {
      // run updateGradInput kernel (non-atomic; currently unreachable since atomic is hardcoded to true)
      adaptivemaxgradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
                                          indices_data+nInputPlane*nOutputCols*nOutputRows, indices_data,
                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
    }
  } else {
    long nInputCols = input->size[3];
    long nInputRows = input->size[2];
    long nInputPlane = input->size[1];
    long nbatch = input->size[0];
    long nOutputCols = gradOutput->size[3];
    long nOutputRows = gradOutput->size[2];

    //bool atomic = (nInputCols%nOutputCols != 0) || (nInputRows%nOutputRows != 0);

    THCudaTensor_resizeAs(state, gradInput, input);
    THCudaTensor_zero(state, gradInput);

    indices_data = THCudaTensor_data(state, indices);
    gradOutput_data = THCudaTensor_data(state, gradOutput);
    gradInput_data = THCudaTensor_data(state, gradInput);

    // cuda blocks & threads:
    int yblocks = (int)(16L / nInputPlane);
    yblocks = yblocks < 1 ? 1 : yblocks;
    dim3 blocks(nInputPlane*nbatch, yblocks);
    dim3 threads(32, 8);

    if(atomic)
    {
      // run updateGradInput kernel, accumulate gradients atomically
      atomicadaptivemaxgradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
                                          indices_data+nbatch*nInputPlane*nOutputCols*nOutputRows, indices_data,
                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
    }
    else
    {
      // run updateGradInput kernel (non-atomic; currently unreachable since atomic is hardcoded to true)
      adaptivemaxgradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
                                          indices_data+nbatch*nInputPlane*nOutputCols*nOutputRows, indices_data,
                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
    }
  }

  // clean
  THCudaTensor_free(state, gradOutput);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in SpatialAdaptiveMaxPooling.updateGradInput: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
  return 1;
}

static const struct luaL_Reg cunn_SpatialAdaptiveMaxPooling__ [] = {
  {"SpatialAdaptiveMaxPooling_updateOutput", cunn_SpatialAdaptiveMaxPooling_updateOutput},
  {"SpatialAdaptiveMaxPooling_updateGradInput", cunn_SpatialAdaptiveMaxPooling_updateGradInput},
  {NULL, NULL}
};

void cunn_SpatialAdaptiveMaxPooling_init(lua_State *L)
{
  luaT_pushmetatable(L, "torch.CudaTensor");
  luaT_registeratname(L, cunn_SpatialAdaptiveMaxPooling__, "nn");
  lua_pop(L, 1);
}

#undef CUDA_MAX_THREADS
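
For reference, the floor/ceil window arithmetic used by these kernels can be checked against a minimal CPU sketch of the forward pass. This is not part of the commit; the function name adaptive_max_pool_ref and its layout assumption (a single contiguous [in_h x in_w] plane) are illustrative only.

#include <float.h>
#include <math.h>

// CPU reference for adaptive max pooling over one contiguous [in_h x in_w]
// plane, mirroring the window math of the adaptivemaxpool kernel above.
static void adaptive_max_pool_ref(const float *in, float *out,
                                  int in_h, int in_w,
                                  int out_h, int out_w)
{
  for (int oy = 0; oy < out_h; oy++) {
    // window rows: floor(oy*in_h/out_h) .. ceil((oy+1)*in_h/out_h)
    int y0 = (int)floorf((float)oy * in_h / out_h);
    int y1 = (int)ceilf((float)(oy + 1) * in_h / out_h);
    for (int ox = 0; ox < out_w; ox++) {
      // window cols: floor(ox*in_w/out_w) .. ceil((ox+1)*in_w/out_w)
      int x0 = (int)floorf((float)ox * in_w / out_w);
      int x1 = (int)ceilf((float)(ox + 1) * in_w / out_w);
      float m = -FLT_MAX;
      for (int y = y0; y < y1; y++)
        for (int x = x0; x < x1; x++)
          m = fmaxf(m, in[y * in_w + x]);
      out[oy * out_w + ox] = m;
    }
  }
}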
