Skip to content

Commit 5db118e

Browse files
fmassa authored and soumith committed
Update LogSoftMax to work in spatial domain
1 parent 60a8a9e commit 5db118e

File tree

1 file changed

+46
-18
lines changed

1 file changed

+46
-18
lines changed

generic/LogSoftMax.c

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,35 @@ void THNN_(LogSoftMax_updateOutput)(
88
THTensor *output)
99
{
1010
real *input_data, *output_data;
11-
long nframe = 0, dim = 0;
11+
long nframe = 0, dim = 0, stride = 0;
1212
long t, d;
1313

1414
if (input->nDimension == 1)
1515
{
1616
nframe = 1;
1717
dim = input->size[0];
18+
stride = 1;
1819
}
1920
else if (input->nDimension == 2)
2021
{
2122
nframe = input->size[0];
2223
dim = input->size[1];
24+
stride = 1;
2325
}
24-
else
26+
else if (input->nDimension == 3)
2527
{
26-
THArgCheck(0, 2, "vector or matrix expected");
28+
nframe = 1;
29+
dim = input->size[0];
30+
stride = input->size[1]*input->size[2];
2731
}
32+
else if (input->nDimension == 4)
33+
{
34+
nframe = input->size[0];
35+
dim = input->size[1];
36+
stride = input->size[2]*input->size[3];
37+
}
38+
else
39+
THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected");
2840

2941
input = THTensor_(newContiguous)(input);
3042
THTensor_(resizeAs)(output, input);
@@ -35,22 +47,22 @@ void THNN_(LogSoftMax_updateOutput)(
3547
accreal logsum;
3648
real maxInput;
3749
#pragma omp parallel for private(t, d, maxInput, logsum, input_data, output_data)
38-
for (t = 0; t < nframe; t++)
50+
for (t = 0; t < stride*nframe; t++)
3951
{
4052
logsum = 0;
4153
maxInput = -THInf;
42-
input_data = input_data0 + dim*t;
43-
output_data = output_data0 + dim*t;
54+
input_data = input_data0 + (t/stride)*dim*stride + t % stride;
55+
output_data = output_data0 + (t/stride)*dim*stride + t % stride;
4456

4557
for (d = 0; d < dim; d++)
46-
maxInput = THMax(maxInput, input_data[d]);
58+
maxInput = THMax(maxInput, input_data[d*stride]);
4759

4860
for (d = 0; d < dim; d++)
49-
logsum += exp(input_data[d] - maxInput);
61+
logsum += exp(input_data[d*stride] - maxInput);
5062
logsum = maxInput + log(logsum);
5163

5264
for (d = 0; d < dim; d++)
53-
output_data[d] = input_data[d] - logsum;
65+
output_data[d*stride] = input_data[d*stride] - logsum;
5466
}
5567

5668
THTensor_(free)(input);
@@ -66,45 +78,61 @@ void THNN_(LogSoftMax_updateGradInput)(
6678
THNN_CHECK_SHAPE(input, gradOutput);
6779
gradOutput = THTensor_(newContiguous)(gradOutput);
6880
real *gradInput_data, *gradOutput_data, *output_data;
69-
long nframe = 0, dim = 0;
81+
long nframe = 0, dim = 0, stride = 0;
7082
long t, d;
7183

7284
if (output->nDimension == 1)
7385
{
7486
nframe = 1;
7587
dim = output->size[0];
88+
stride = 1;
7689
}
7790
else if (output->nDimension == 2)
7891
{
7992
nframe = output->size[0];
8093
dim = output->size[1];
94+
stride = 1;
8195
}
82-
else
96+
else if (output->nDimension == 3)
8397
{
84-
THError("vector or matrix expected");
98+
nframe = 1;
99+
dim = output->size[0];
100+
stride = output->size[1]*output->size[2];
85101
}
102+
else if (output->nDimension == 4)
103+
{
104+
nframe = output->size[0];
105+
dim = output->size[1];
106+
stride = output->size[2]*output->size[3];
107+
}
108+
else
109+
THError("1D, 2D, 3D or 4D tensor expected");
110+
111+
output = THTensor_(newContiguous)(output);
112+
gradOutput = THTensor_(newContiguous)(gradOutput);
86113

87114
THTensor_(resizeAs)(gradInput, output);
88115
real *gradInput_data0 = THTensor_(data)(gradInput);
89116
real *output_data0 = THTensor_(data)(output);
90117
real *gradOutput_data0 = THTensor_(data)(gradOutput);
91118
accreal sum;
92119
#pragma omp parallel for private(t, sum, d, gradInput_data, output_data, gradOutput_data)
93-
for (t = 0; t < nframe; t++)
120+
for (t = 0; t < stride*nframe; t++)
94121
{
95122
sum = 0;
96-
gradInput_data = gradInput_data0 + dim*t;
97-
output_data = output_data0 + dim*t;
98-
gradOutput_data = gradOutput_data0 + dim*t;
123+
gradInput_data = gradInput_data0 + (t/stride)*dim*stride + t % stride;
124+
output_data = output_data0 + (t/stride)*dim*stride + t % stride;
125+
gradOutput_data = gradOutput_data0 + (t/stride)*dim*stride + t % stride;
99126

100127
for (d = 0; d < dim; d++)
101-
sum += gradOutput_data[d];
128+
sum += gradOutput_data[d*stride];
102129

103130
for (d = 0; d < dim; d++)
104-
gradInput_data[d] = gradOutput_data[d] - exp(output_data[d])*sum;
131+
gradInput_data[d*stride] = gradOutput_data[d*stride] - exp(output_data[d*stride])*sum;
105132
}
106133

107134
THTensor_(free)(gradOutput);
135+
THTensor_(free)(output);
108136
}
109137

110138
#endif

0 commit comments

Comments (0)