Skip to content

Commit 130ed2c

Browse files
committed
Move SpatialConvolutionMM.c -> lib/THNN/generic
1 parent 297b393 commit 130ed2c

File tree

1 file changed

+265
-0
lines changed

1 file changed

+265
-0
lines changed

generic/SpatialConvolutionMM.c

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
#ifndef TH_GENERIC_FILE
2+
#define TH_GENERIC_FILE "generic/SpatialConvolutionMM.c"
3+
#else
4+
5+
#ifdef _WIN32
6+
# include <windows.h>
7+
#endif
8+
9+
#include "unfold.h"
10+
11+
12+
/*
 * Forward pass for a single frame.
 *
 * Steps, in order:
 *   1. nn_(unfolded_copy) fills finput from input using the kernel, stride
 *      and padding geometry.
 *   2. Each of the nOutputPlane planes of output is filled with the matching
 *      bias entry.
 *   3. output, viewed as an nOutputPlane x (outputHeight*outputWidth) matrix
 *      over the same storage, gets weight * finput added to it.
 *
 * NOTE(review): the 2D view is created with stride arguments of -1, which
 * presumably means "default contiguous strides" in TH -- confirm. The bias
 * fill addresses planes through output->stride[0].
 */
static void nn_(SpatialConvolutionMM_updateOutput_frame)(THTensor *input, THTensor *output, THTensor *weight, THTensor *bias, THTensor *finput,
                                                         int kW, int kH, int dW, int dH, int padW, int padH,
                                                         long nInputPlane, long inputWidth, long inputHeight,
                                                         long nOutputPlane, long outputWidth, long outputHeight)
{
  const long planeSize = outputHeight*outputWidth;
  THTensor *outputAsMatrix;
  long plane;

  /* Step 1: unfold the input patches into finput. */
  nn_(unfolded_copy)(finput, input,
                     kW, kH, dW, dH, padW, padH,
                     nInputPlane, inputWidth, inputHeight,
                     outputWidth, outputHeight);

  /* Matrix-shaped alias of output; shares output's storage. */
  outputAsMatrix = THTensor_(newWithStorage2d)(output->storage, output->storageOffset,
                                               nOutputPlane, -1,
                                               planeSize, -1);

  /* Step 2: seed every output plane with its bias value. */
  for(plane = 0; plane < nOutputPlane; plane++)
  {
    THVector_(fill)(output->storage->data + output->storageOffset + output->stride[0]*plane,
                    THTensor_(get1d)(bias, plane),
                    planeSize);
  }

  /* Step 3: output += weight * finput. */
  THTensor_(addmm)(outputAsMatrix, 1, outputAsMatrix, 1, weight, finput);

  THTensor_(free)(outputAsMatrix);
}
33+
34+
/*
 * Lua binding: forward pass of SpatialConvolutionMM.
 *
 * Lua stack on entry:
 *   index 1 - object providing the fields kW, kH, dW, dH, padW, padH,
 *             finput, weight, bias and output
 *   index 2 - input tensor, 3D or 4D (batch mode)
 *
 * Computes the output geometry, resizes finput and output accordingly and
 * runs the per-frame kernel once (3D input) or once per entry along
 * dimension 0 (4D input, inside an OpenMP parallel loop).
 *
 * Errors are raised when the input is not 3D/4D, when kernel size or stride
 * is not positive, when the computed output size is smaller than 1, or when
 * nInputPlane*kW*kH does not equal weight->size[1].
 *
 * Returns 1, i.e. hands the value on top of the Lua stack back to the caller.
 * NOTE(review): this presumably relies on luaT_getfieldcheckudata leaving the
 * fetched field on the stack, which would make "output" the returned value
 * only as long as it stays the LAST field lookup -- confirm against luaT
 * before reordering the lookups below.
 */
static int nn_(SpatialConvolutionMM_updateOutput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  int padW = luaT_getfieldcheckint(L, 1, "padW");
  int padH = luaT_getfieldcheckint(L, 1, "padH");

  THTensor *finput = luaT_getfieldcheckudata(L, 1, "finput", torch_Tensor);
  THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
  THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
  /* Keep this lookup last: see the note on the return value above. */
  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

  /* Dimension indices of (feature, height, width) for a 3D input;
     shifted by one below when a leading batch dimension is present. */
  int dimf = 0;
  int dimw = 2;
  int dimh = 1;

  long nInputPlane;
  long inputWidth;
  long inputHeight;
  long nOutputPlane;
  long outputWidth;
  long outputHeight;

  luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D(batch mode) tensor expected");

  /* dW/dH divide the output-size computation and kW*kH divides the
     channel-count diagnostic below; a zero there would be a division by zero. */
  luaL_argcheck(L, kW > 0 && kH > 0, 1, "kernel size should be greater than zero");
  luaL_argcheck(L, dW > 0 && dH > 0, 1, "stride should be greater than zero");

  if (input->nDimension == 4) {
    dimf++;
    dimw++;
    dimh++;
  }

  nInputPlane = input->size[dimf];
  inputWidth   = input->size[dimw];
  inputHeight  = input->size[dimh];
  nOutputPlane = weight->size[0];
  outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
  outputHeight = (inputHeight + 2*padH - kH) / dH + 1;

  /* All six values are long, so they must be printed with %ld. */
  if (outputWidth < 1 || outputHeight < 1)
    THError("Given input size: (%ldx%ldx%ld). Calculated output size: (%ldx%ldx%ld). Output size is too small",
            nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);

  if (nInputPlane*kW*kH != weight->size[1])
    THError("Wrong number of input channels! Input has %ld channels, expected %ld",nInputPlane,weight->size[1]/(kW*kH));

  if(input->nDimension == 3)
  {
    THTensor_(resize2d)(finput, kW*kH*nInputPlane, outputHeight*outputWidth);
    THTensor_(resize3d)(output, nOutputPlane, outputHeight, outputWidth);

    nn_(SpatialConvolutionMM_updateOutput_frame)(input, output, weight, bias, finput,
                                                 kW, kH, dW, dH, padW, padH,
                                                 nInputPlane, inputWidth, inputHeight,
                                                 nOutputPlane, outputWidth, outputHeight);
  }
  else
  {
    long T = input->size[0];
    long t;

    THTensor_(resize3d)(finput, T, kW*kH*nInputPlane, outputHeight*outputWidth);
    THTensor_(resize4d)(output, T, nOutputPlane, outputHeight, outputWidth);

    /* Each iteration works on its own slice of input, output and finput. */
#pragma omp parallel for private(t)
    for(t = 0; t < T; t++)
    {
      THTensor *input_t = THTensor_(newSelect)(input, 0, t);
      THTensor *output_t = THTensor_(newSelect)(output, 0, t);
      THTensor *finput_t = THTensor_(newSelect)(finput, 0, t);

      nn_(SpatialConvolutionMM_updateOutput_frame)(input_t, output_t, weight, bias, finput_t,
                                                   kW, kH, dW, dH, padW, padH,
                                                   nInputPlane, inputWidth, inputHeight,
                                                   nOutputPlane, outputWidth, outputHeight);

      THTensor_(free)(input_t);
      THTensor_(free)(output_t);
      THTensor_(free)(finput_t);
    }
  }

  return 1;
}
121+
122+
123+
/*
 * Backward pass w.r.t. the input for a single frame.
 *
 * gradOutput is viewed as a size[0] x (size[1]*size[2]) matrix over its own
 * storage; fgradInput is overwritten with weight * that matrix (addmm with a
 * beta of 0). gradInput is then zeroed and nn_(unfolded_acc) accumulates
 * fgradInput into it.
 *
 * The caller is expected to pass the weight already oriented for this
 * product; this function uses it as given.
 *
 * NOTE(review): the 2D view uses stride arguments of -1, presumably
 * "default contiguous strides" -- confirm against TH.
 */
static void nn_(SpatialConvolutionMM_updateGradInput_frame)(THTensor *gradInput, THTensor *gradOutput, THTensor *weight, THTensor *fgradInput,
                                                            int kW, int kH, int dW, int dH, int padW, int padH)
{
  const long nOutputPlane = gradOutput->size[0];
  const long outputHeight = gradOutput->size[1];
  const long outputWidth  = gradOutput->size[2];
  THTensor *gradOutputAsMatrix;

  gradOutputAsMatrix = THTensor_(newWithStorage2d)(gradOutput->storage, gradOutput->storageOffset,
                                                   nOutputPlane, -1,
                                                   outputHeight*outputWidth, -1);

  /* fgradInput = 0*fgradInput + 1*(weight * gradOutput) */
  THTensor_(addmm)(fgradInput, 0, fgradInput, 1, weight, gradOutputAsMatrix);
  THTensor_(free)(gradOutputAsMatrix);

  /* unfolded_acc accumulates, so start from a cleared gradInput. */
  THTensor_(zero)(gradInput);

  nn_(unfolded_acc)(fgradInput, gradInput,
                    kW, kH, dW, dH, padW, padH,
                    gradInput->size[0],   /* planes */
                    gradInput->size[2],   /* width  */
                    gradInput->size[1],   /* height */
                    outputWidth, outputHeight);
}
136+
137+
/*
 * Lua binding: gradient of SpatialConvolutionMM with respect to its input.
 *
 * Lua stack on entry:
 *   index 1 - object providing kW, kH, dW, dH, padW, padH, nOutputPlane,
 *             finput, fgradInput, weight and gradInput
 *   index 2 - input tensor (3D, or 4D in batch mode)
 *   index 3 - gradOutput tensor
 *
 * gradInput is resized like input and fgradInput like finput, then the
 * per-frame kernel runs once (3D) or once per entry along dimension 0
 * (4D, inside an OpenMP parallel loop).
 *
 * The transposed weight is a temporary view created here; the module's own
 * weight tensor is never modified, so an error raised by a callee cannot
 * leave it transposed.
 *
 * Returns 1, i.e. hands the value on top of the Lua stack back to the caller.
 * NOTE(review): this presumably relies on luaT_getfieldcheckudata leaving the
 * fetched field on the stack, which would make "gradInput" the returned value
 * only as long as it stays the LAST field lookup -- confirm against luaT
 * before reordering the lookups below.
 */
static int nn_(SpatialConvolutionMM_updateGradInput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
  int kW = luaT_getfieldcheckint(L, 1, "kW");
  int kH = luaT_getfieldcheckint(L, 1, "kH");
  int dW = luaT_getfieldcheckint(L, 1, "dW");
  int dH = luaT_getfieldcheckint(L, 1, "dH");
  int padW = luaT_getfieldcheckint(L, 1, "padW");
  int padH = luaT_getfieldcheckint(L, 1, "padH");
  int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");

  THTensor *finput = luaT_getfieldcheckudata(L, 1, "finput", torch_Tensor);
  THTensor *fgradInput = luaT_getfieldcheckudata(L, 1, "fgradInput", torch_Tensor);
  THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
  /* Keep this lookup last: see the note on the return value above. */
  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);

  THTensor *tweight;

  THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 1, "Number of output features is not equal to nOutputPlane" );

  THTensor_(resizeAs)(gradInput, input);
  THTensor_(resizeAs)(fgradInput, finput);

  /* fgradInput is the beta=0 accumulator of the frame's addmm. Clear it so
     that whatever resizeAs left in the buffer cannot reach the result.
     NOTE(review): whether stale values can leak through a zero beta depends
     on the gemm backend; zeroing makes the result independent of it. */
  THTensor_(zero)(fgradInput);

  /* Transposed view of weight (dims 0 and 1 swapped) used by every frame. */
  tweight = THTensor_(newTranspose)(weight, 0, 1);

  if(input->nDimension == 3)
  {
    nn_(SpatialConvolutionMM_updateGradInput_frame)(gradInput, gradOutput, tweight, fgradInput, kW, kH, dW, dH, padW, padH);
  }
  else
  {
    long T = input->size[0];
    long t;

    /* Each iteration works on its own slice; tweight is only read. */
#pragma omp parallel for private(t)
    for(t = 0; t < T; t++)
    {
      THTensor *gradInput_t = THTensor_(newSelect)(gradInput, 0, t);
      THTensor *gradOutput_t = THTensor_(newSelect)(gradOutput, 0, t);
      THTensor *fgradInput_t = THTensor_(newSelect)(fgradInput, 0, t);

      nn_(SpatialConvolutionMM_updateGradInput_frame)(gradInput_t, gradOutput_t, tweight, fgradInput_t, kW, kH, dW, dH, padW, padH);

      THTensor_(free)(gradInput_t);
      THTensor_(free)(gradOutput_t);
      THTensor_(free)(fgradInput_t);
    }
  }

  THTensor_(free)(tweight);

  return 1;
}
188+
189+
static void nn_(SpatialConvolutionMM_accGradParameters_frame)(THTensor *gradOutput, THTensor *gradWeight, THTensor *gradBias, THTensor *finput,
190+
real scale)
191+
{
192+
long i;
193+
THTensor *gradOutput2d = THTensor_(newWithStorage2d)(gradOutput->storage, gradOutput->storageOffset,
194+
gradOutput->size[0], -1,
195+
gradOutput->size[1]*gradOutput->size[2], -1);
196+
197+
THTensor_(transpose)(finput, finput, 0, 1);
198+
THTensor_(addmm)(gradWeight, 1, gradWeight, scale, gradOutput2d, finput);
199+
THTensor_(transpose)(finput, finput, 0, 1);
200+
201+
for(i = 0; i < gradBias->size[0]; i++)
202+
{
203+
long k;
204+
real sum = 0;
205+
real *data = gradOutput2d->storage->data + gradOutput2d->storageOffset + i*gradOutput2d->stride[0];
206+
for(k = 0; k < gradOutput2d->size[1]; k++)
207+
sum += data[k];
208+
(gradBias->storage->data + gradBias->storageOffset)[i] += scale*sum;
209+
}
210+
211+
THTensor_(free)(gradOutput2d);
212+
}
213+
214+
/*
 * Lua binding: accumulate gradients into gradWeight and gradBias.
 *
 * Lua stack on entry:
 *   index 1 - object providing nOutputPlane, finput, gradWeight, gradBias
 *   index 2 - input tensor (3D, or 4D in batch mode)
 *   index 3 - gradOutput tensor
 *   index 4 - optional scale factor, defaults to 1
 *
 * A 3D input is handled with a single call to the per-frame kernel. For a
 * 4D input the frames along dimension 0 are processed one after another in
 * a plain sequential loop; every frame adds into the same gradWeight and
 * gradBias.
 *
 * Returns 0 (no values handed back to Lua).
 */
static int nn_(SpatialConvolutionMM_accGradParameters)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
  real scale = luaL_optnumber(L, 4, 1);
  int nOutputPlane = luaT_getfieldcheckint(L, 1, "nOutputPlane");

  THTensor *finput = luaT_getfieldcheckudata(L, 1, "finput", torch_Tensor);
  THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
  THTensor *gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);

  long batchSize;
  long frame;

  /* The feature dimension of gradOutput is 1 in batch mode, 0 otherwise. */
  THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 1, "Number of output features is not equal to nOutputPlane" );

  /* Single-frame case: nothing to slice. */
  if(input->nDimension == 3)
  {
    nn_(SpatialConvolutionMM_accGradParameters_frame)(gradOutput, gradWeight, gradBias, finput, scale);
    return 0;
  }

  /* Batch case: walk the leading dimension frame by frame. */
  batchSize = input->size[0];
  for(frame = 0; frame < batchSize; frame++)
  {
    THTensor *gradOutputFrame = THTensor_(newSelect)(gradOutput, 0, frame);
    THTensor *finputFrame = THTensor_(newSelect)(finput, 0, frame);

    nn_(SpatialConvolutionMM_accGradParameters_frame)(gradOutputFrame, gradWeight, gradBias, finputFrame, scale);

    THTensor_(free)(gradOutputFrame);
    THTensor_(free)(finputFrame);
  }

  return 0;
}
250+
251+
static const struct luaL_Reg nn_(SpatialConvolutionMM__) [] = {
252+
{"SpatialConvolutionMM_updateOutput", nn_(SpatialConvolutionMM_updateOutput)},
253+
{"SpatialConvolutionMM_updateGradInput", nn_(SpatialConvolutionMM_updateGradInput)},
254+
{"SpatialConvolutionMM_accGradParameters", nn_(SpatialConvolutionMM_accGradParameters)},
255+
{NULL, NULL}
256+
};
257+
258+
/*
 * Registers the SpatialConvolutionMM entry points with the Lua state.
 *
 * Pushes the metatable of torch_Tensor, registers the function table under
 * the name "nn", then pops one value so the stack is left as it was found.
 *
 * NOTE(review): presumably luaT_registeratname attaches the functions to an
 * "nn" sub-table of the metatable on top of the stack -- confirm against
 * luaT.
 */
static void nn_(SpatialConvolutionMM_init)(lua_State *L)
{
  /* Target of the registration: the tensor type's metatable. */
  luaT_pushmetatable(L, torch_Tensor);

  luaT_registeratname(L, nn_(SpatialConvolutionMM__), "nn");

  /* Drop the metatable pushed above. */
  lua_pop(L, 1);
}
264+
265+
#endif

0 commit comments

Comments
 (0)