Skip to content

Commit 3d484ec

Browse files
lantiga authored and soumith committed
Add 3D upsampling (nearest and trilinear) with tests
1 parent a9c4d64 commit 3d484ec

File tree

5 files changed

+587
-0
lines changed

5 files changed

+587
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include "THCUNN.h"
2+
#include "common.h"
3+
4+
#include <thrust/transform.h>
5+
#include <thrust/reduce.h>
6+
#include <thrust/transform_reduce.h>
7+
#include <thrust/functional.h>
8+
9+
#include "THCHalf.h"
10+
#include "THCHalfAutoNumerics.cuh"
11+
12+
/*
13+
* Description:
14+
*/
15+
16+
/*
 * Map a flat index into the upscaled output tensor (dims: outer x d1 x d2 x d3 x d4)
 * to the flat index of the nearest-neighbour source element in the input tensor,
 * whose last three dimensions are smaller by scale_factor.
 * Assumes d2, d3, d4 are exact multiples of scale_factor (output sizes).
 */
__device__ int translate_idx(int ii, int d1, int d2, int d3, int d4, int scale_factor)
{
  // Decompose the flat output index into coordinates, innermost dim first.
  const int v = ii % d4; ii /= d4;
  const int w = ii % d3; ii /= d3;
  const int z = ii % d2; ii /= d2;
  const int y = ii % d1; ii /= d1;
  const int x = ii;
  // Nearest neighbour: integer division maps each output coordinate in the
  // last three dims onto its source coordinate in the smaller input.
  const int zi = z / scale_factor;
  const int wi = w / scale_factor;
  const int vi = v / scale_factor;
  // Input sizes along the scaled dimensions.
  const int id2 = d2 / scale_factor;
  const int id3 = d3 / scale_factor;
  const int id4 = d4 / scale_factor;
  // Re-linearize against the input sizes.
  return ((((x * d1 + y) * id2 + zi) * id3 + wi) * id4 + vi);
}
37+
/*
 * Inverse of translate_idx: map a flat index into the (small) gradInput tensor
 * to the flat index of one contributing element in the (large) gradOutput
 * tensor. (off_x, off_y, off_z) select which of the scale_factor^3 upscaled
 * elements is addressed; each offset must lie in [0, scale_factor).
 */
__device__ int translate_idx_inv(int ii, int d1, int d2, int d3, int d4,
                                 int scale_factor, int off_x, int off_y, int off_z)
{
  // Decompose the flat input-side index, innermost dim first.
  const int v = ii % d4; ii /= d4;
  const int w = ii % d3; ii /= d3;
  const int z = ii % d2; ii /= d2;
  const int y = ii % d1; ii /= d1;
  const int x = ii;
  // Scale up the last three coordinates and apply the requested offset.
  const int vo = v * scale_factor + off_x;
  const int wo = w * scale_factor + off_y;
  const int zo = z * scale_factor + off_z;
  // Output sizes along the scaled dimensions.
  const int od2 = d2 * scale_factor;
  const int od3 = d3 * scale_factor;
  const int od4 = d4 * scale_factor;
  // Re-linearize against the output sizes.
  return ((((x * d1 + y) * od2 + zo) * od3 + wo) * od4 + vo);
}
58+
59+
/*
 * Description:
 *   Nearest-neighbour 3D upsampling, forward pass. Each thread writes at most
 *   one output element, copying the input element selected by translate_idx.
 *   Launched on a 2D grid; no_elements is the total output element count.
 */
template <typename Dtype>
__global__ void vupscale(Dtype *input, Dtype *output, long no_elements,
                         int scale_factor, int d1, int d2, int d3, int d4)
{
  // Flat output offset across the 2D launch grid.
  long ii = threadIdx.x + blockDim.x * blockIdx.x;
  ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
  if (ii >= no_elements) {
    return;  // guard the grid tail
  }
  const int ipidx = translate_idx(ii, d1, d2, d3, d4, scale_factor);
  output[ii] = input[ipidx];
}
70+
71+
/*
72+
* Description:
73+
*/
74+
/*
 * Description:
 *   Nearest-neighbour 3D upsampling, backward pass. Each thread accumulates
 *   (in Acctype, to limit rounding error for half/float) the scale_factor^3
 *   gradOutput values that were copied from its gradInput element, then adds
 *   the sum into gradInput. Launched on a 2D grid over gradInput elements.
 */
template <typename Dtype, typename Acctype>
__global__ void vdownscale(Dtype *gradInput_data, Dtype *gradOutput_data,
                           long no_elements, int scale_factor,
                           int d1, int d2, int d3, int d4)
{
  // Flat gradInput offset across the 2D launch grid.
  long ii = threadIdx.x + blockDim.x * blockIdx.x;
  ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
  if (ii >= no_elements) {
    return;  // guard the grid tail
  }
  Acctype sum = Acctype(0);
  // Gather all upscaled positions that map back to this input element.
  for (int a = 0; a < scale_factor; a++) {
    for (int b = 0; b < scale_factor; b++) {
      for (int c = 0; c < scale_factor; c++) {
        const int ipidx = translate_idx_inv(ii, d1, d2, d3, d4, scale_factor, a, b, c);
        sum += gradOutput_data[ipidx];
      }
    }
  }
  // Accumulate (+=) rather than overwrite, matching the forward adjoint.
  gradInput_data[ii] += ScalarConvert<Acctype, Dtype>::to(sum);
}
93+
94+
#include "generic/VolumetricUpSamplingNearest.cu"
95+
#include "THCGenerateFloatTypes.h"
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Adapted from interp.cpp from Caffe util by Pauline Luc
2+
// Originally developed by George Papandreou
3+
#include "THCUNN.h"
4+
#include "common.h"
5+
#include "THCDeviceTensor.cuh"
6+
#include "THCDeviceTensorUtils.cuh"
7+
#include "THCDeviceUtils.cuh"
8+
#include "THCHalf.h"
9+
#include "THCHalfAutoNumerics.cuh"
10+
#include "THCAtomics.cuh"
11+
12+
// Trilinear 3D upsampling, forward pass (adapted from Caffe's interp.cpp by
// Pauline Luc; originally developed by George Papandreou).
//
// One thread per spatial output location (t2, h2, w2); each thread loops over
// every batch and channel. `n` is the number of spatial output locations,
// i.e. depth2*height2*width2. rdepth/rheight/rwidth are the precomputed
// (input-1)/(output-1) coordinate ratios used to find the source point.
//
// Fix vs. original: the batch loop variable was named `n`, shadowing the
// kernel parameter `n` (the element count); it is renamed to `b`.
template<typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
    const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  const int batchsize = data1.getSize(0);
  const int channels = data1.getSize(1);
  const int depth1 = data1.getSize(2);
  const int height1 = data1.getSize(3);
  const int width1 = data1.getSize(4);
  const int depth2 = data2.getSize(2);
  const int height2 = data2.getSize(3);
  const int width2 = data2.getSize(4);

  if (index < n) {
    // Decompose the flat index into output coordinates.
    const int w2 = (index % (height2*width2)) % width2;  // 0:width2-1
    const int h2 = (index % (height2*width2)) / width2;  // 0:height2-1
    const int t2 = index / (height2*width2);             // 0:depth2-1
    // Special case: identical sizes -- plain copy, no interpolation needed.
    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
      const int t1 = t2;
      const int h1 = h2;
      const int w1 = w2;
      for (int b = 0; b < batchsize; b++) {
        for (int c = 0; c < channels; ++c) {
          const Dtype val = data1[b][c][t1][h1][w1];
          data2[b][c][t2][h2][w2] = val;
        }
      }
      return;
    }
    // Source coordinate and interpolation weights along depth.
    const Acctype t1r = rdepth * t2;
    const int t1 = t1r;                         // truncation == floor (non-negative)
    const int t1p = (t1 < depth1 - 1) ? 1 : 0;  // clamp +1 neighbour at the border
    const Acctype t1lambda = t1r - t1;
    const Acctype t0lambda = Acctype(1) - t1lambda;
    // Source coordinate and interpolation weights along height.
    const Acctype h1r = rheight * h2;
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const Acctype h1lambda = h1r - h1;
    const Acctype h0lambda = Acctype(1) - h1lambda;
    // Source coordinate and interpolation weights along width.
    const Acctype w1r = rwidth * w2;
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const Acctype w1lambda = w1r - w1;
    const Acctype w0lambda = Acctype(1) - w1lambda;
    // Weighted sum over the 8 surrounding input voxels, accumulated in Acctype.
    for (int b = 0; b < batchsize; b++) {
      for (int c = 0; c < channels; ++c) {
        const Acctype val = t0lambda * (h0lambda * (w0lambda * data1[b][c][t1][h1][w1]
                                                  + w1lambda * data1[b][c][t1][h1][w1+w1p])
                                      + h1lambda * (w0lambda * data1[b][c][t1][h1+h1p][w1]
                                                  + w1lambda * data1[b][c][t1][h1+h1p][w1+w1p]))
                          + t1lambda * (h0lambda * (w0lambda * data1[b][c][t1+t1p][h1][w1]
                                                  + w1lambda * data1[b][c][t1+t1p][h1][w1+w1p])
                                      + h1lambda * (w0lambda * data1[b][c][t1+t1p][h1+h1p][w1]
                                                  + w1lambda * data1[b][c][t1+t1p][h1+h1p][w1+w1p]));
        data2[b][c][t2][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
      }
    }
  }
}
77+
78+
// Backward (adjoint) operation 1 <- 2 (accumulates)
79+
// Backward (adjoint) operation 1 <- 2 (accumulates).
//
// Trilinear 3D upsampling, backward pass. One thread per spatial gradOutput
// location (t2, h2, w2); each thread loops over every batch and channel and
// scatters the weighted gradient into the 8 surrounding gradInput voxels.
// atomicAdd is required because neighbouring output locations write to
// overlapping input voxels.
//
// Fix vs. original: the batch loop variable was named `n`, shadowing the
// kernel parameter `n` (the element count); it is renamed to `b`.
template <typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel_backward(const int n,
    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
    THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  const int batchsize = data1.getSize(0);
  const int channels = data1.getSize(1);
  const int depth1 = data1.getSize(2);
  const int height1 = data1.getSize(3);
  const int width1 = data1.getSize(4);
  const int depth2 = data2.getSize(2);
  const int height2 = data2.getSize(3);
  const int width2 = data2.getSize(4);
  if (index < n) {
    // Decompose the flat index into gradOutput coordinates.
    const int w2 = (index % (height2*width2)) % width2;  // 0:width2-1
    const int h2 = (index % (height2*width2)) / width2;  // 0:height2-1
    const int t2 = index / (height2*width2);             // 0:depth2-1
    // Special case: identical sizes -- gradient passes straight through.
    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
      const int t1 = t2;
      const int h1 = h2;
      const int w1 = w2;
      for (int b = 0; b < batchsize; b++) {
        for (int c = 0; c < channels; ++c) {
          const Dtype val = data2[b][c][t1][h1][w1];
          data1[b][c][t2][h2][w2] += val;
        }
      }
      return;
    }
    // Source coordinate and interpolation weights along depth.
    const Acctype t1r = rdepth * t2;
    const int t1 = t1r;                         // truncation == floor (non-negative)
    const int t1p = (t1 < depth1 - 1) ? 1 : 0;  // clamp +1 neighbour at the border
    const Acctype t1lambda = t1r - t1;
    const Acctype t0lambda = Acctype(1) - t1lambda;
    // Source coordinate and interpolation weights along height.
    const Acctype h1r = rheight * h2;
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const Acctype h1lambda = h1r - h1;
    const Acctype h0lambda = Acctype(1) - h1lambda;
    // Source coordinate and interpolation weights along width.
    const Acctype w1r = rwidth * w2;
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const Acctype w1lambda = w1r - w1;
    const Acctype w0lambda = Acctype(1) - w1lambda;
    // Scatter the weighted gradient to the 8 contributing input voxels.
    for (int b = 0; b < batchsize; b++) {
      for (int c = 0; c < channels; ++c) {
        const Dtype d2val = data2[b][c][t2][h2][w2];
        atomicAdd(data1[b][c][t1][h1][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h0lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1][h1][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h0lambda * w1lambda * d2val));
        atomicAdd(data1[b][c][t1][h1+h1p][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h1lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1][h1+h1p][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h1lambda * w1lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h0lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h0lambda * w1lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1+h1p][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h1lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1+h1p][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h1lambda * w1lambda * d2val));
      }
    }
  }
}
152+
153+
154+
#include "generic/VolumetricUpSamplingTrilinear.cu"
155+
#include "THCGenerateFloatTypes.h"

lib/THCUNN/generic/THCUNN.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,4 +1370,38 @@ TH_API void THNN_(VolumetricReplicationPadding_updateGradInput)(
13701370
int ptop, int pbottom,
13711371
int pfront, int pback);
13721372

1373+
TH_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
1374+
THCState *state,
1375+
THCTensor *input,
1376+
THCTensor *gradOutput,
1377+
THCTensor *gradInput,
1378+
int scale_factor);
1379+
1380+
TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)(
1381+
THCState *state,
1382+
THCTensor *input,
1383+
THCTensor *output,
1384+
int scale_factor);
1385+
1386+
TH_API void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
1387+
THCState *state,
1388+
THCTensor *input,
1389+
THCTensor *output,
1390+
int outputDepth,
1391+
int outputHeight,
1392+
int outputWidth);
1393+
1394+
TH_API void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
1395+
THCState *state,
1396+
THCTensor *gradOutput,
1397+
THCTensor *gradInput,
1398+
int nbatch,
1399+
int nchannels,
1400+
int inputDepth,
1401+
int inputHeight,
1402+
int inputWidth,
1403+
int outputDepth,
1404+
int outputHeight,
1405+
int outputWidth);
1406+
13731407
#endif

0 commit comments

Comments
 (0)