@@ -8,23 +8,35 @@ void THNN_(LogSoftMax_updateOutput)(
88 THTensor * output )
99{
1010 real * input_data , * output_data ;
11- long nframe = 0 , dim = 0 ;
11+ long nframe = 0 , dim = 0 , stride = 0 ;
1212 long t , d ;
1313
1414 if (input -> nDimension == 1 )
1515 {
1616 nframe = 1 ;
1717 dim = input -> size [0 ];
18+ stride = 1 ;
1819 }
1920 else if (input -> nDimension == 2 )
2021 {
2122 nframe = input -> size [0 ];
2223 dim = input -> size [1 ];
24+ stride = 1 ;
2325 }
24- else
26+ else if ( input -> nDimension == 3 )
2527 {
26- THArgCheck (0 , 2 , "vector or matrix expected" );
28+ nframe = 1 ;
29+ dim = input -> size [0 ];
30+ stride = input -> size [1 ]* input -> size [2 ];
2731 }
32+ else if (input -> nDimension == 4 )
33+ {
34+ nframe = input -> size [0 ];
35+ dim = input -> size [1 ];
36+ stride = input -> size [2 ]* input -> size [3 ];
37+ }
38+ else
39+ THArgCheck (0 , 2 , "1D, 2D, 3D or 4D tensor expected" );
2840
2941 input = THTensor_ (newContiguous )(input );
3042 THTensor_ (resizeAs )(output , input );
@@ -35,22 +47,22 @@ void THNN_(LogSoftMax_updateOutput)(
3547 accreal logsum ;
3648 real maxInput ;
3749 #pragma omp parallel for private(t, d, maxInput, logsum, input_data, output_data)
38- for (t = 0 ; t < nframe ; t ++ )
50+ for (t = 0 ; t < stride * nframe ; t ++ )
3951 {
4052 logsum = 0 ;
4153 maxInput = - THInf ;
42- input_data = input_data0 + dim * t ;
43- output_data = output_data0 + dim * t ;
54+ input_data = input_data0 + ( t / stride ) * dim * stride + t % stride ;
55+ output_data = output_data0 + ( t / stride ) * dim * stride + t % stride ;
4456
4557 for (d = 0 ; d < dim ; d ++ )
46- maxInput = THMax (maxInput , input_data [d ]);
58+ maxInput = THMax (maxInput , input_data [d * stride ]);
4759
4860 for (d = 0 ; d < dim ; d ++ )
49- logsum += exp (input_data [d ] - maxInput );
61+ logsum += exp (input_data [d * stride ] - maxInput );
5062 logsum = maxInput + log (logsum );
5163
5264 for (d = 0 ; d < dim ; d ++ )
53- output_data [d ] = input_data [d ] - logsum ;
65+ output_data [d * stride ] = input_data [d * stride ] - logsum ;
5466 }
5567
5668 THTensor_ (free )(input );
@@ -66,45 +78,61 @@ void THNN_(LogSoftMax_updateGradInput)(
6678 THNN_CHECK_SHAPE (input , gradOutput );
6779 gradOutput = THTensor_ (newContiguous )(gradOutput );
6880 real * gradInput_data , * gradOutput_data , * output_data ;
69- long nframe = 0 , dim = 0 ;
81+ long nframe = 0 , dim = 0 , stride = 0 ;
7082 long t , d ;
7183
7284 if (output -> nDimension == 1 )
7385 {
7486 nframe = 1 ;
7587 dim = output -> size [0 ];
88+ stride = 1 ;
7689 }
7790 else if (output -> nDimension == 2 )
7891 {
7992 nframe = output -> size [0 ];
8093 dim = output -> size [1 ];
94+ stride = 1 ;
8195 }
82- else
96+ else if ( output -> nDimension == 3 )
8397 {
84- THError ("vector or matrix expected" );
98+ nframe = 1 ;
99+ dim = output -> size [0 ];
100+ stride = output -> size [1 ]* output -> size [2 ];
85101 }
102+ else if (output -> nDimension == 4 )
103+ {
104+ nframe = output -> size [0 ];
105+ dim = output -> size [1 ];
106+ stride = output -> size [2 ]* output -> size [3 ];
107+ }
108+ else
109+ THError ("1D, 2D, 3D or 4D tensor expected" );
110+
111+ output = THTensor_ (newContiguous )(output );
112+ gradOutput = THTensor_ (newContiguous )(gradOutput );
86113
87114 THTensor_ (resizeAs )(gradInput , output );
88115 real * gradInput_data0 = THTensor_ (data )(gradInput );
89116 real * output_data0 = THTensor_ (data )(output );
90117 real * gradOutput_data0 = THTensor_ (data )(gradOutput );
91118 accreal sum ;
92119 #pragma omp parallel for private(t, sum, d, gradInput_data, output_data, gradOutput_data)
93- for (t = 0 ; t < nframe ; t ++ )
120+ for (t = 0 ; t < stride * nframe ; t ++ )
94121 {
95122 sum = 0 ;
96- gradInput_data = gradInput_data0 + dim * t ;
97- output_data = output_data0 + dim * t ;
98- gradOutput_data = gradOutput_data0 + dim * t ;
123+ gradInput_data = gradInput_data0 + ( t / stride ) * dim * stride + t % stride ;
124+ output_data = output_data0 + ( t / stride ) * dim * stride + t % stride ;
125+ gradOutput_data = gradOutput_data0 + ( t / stride ) * dim * stride + t % stride ;
99126
100127 for (d = 0 ; d < dim ; d ++ )
101- sum += gradOutput_data [d ];
128+ sum += gradOutput_data [d * stride ];
102129
103130 for (d = 0 ; d < dim ; d ++ )
104- gradInput_data [d ] = gradOutput_data [d ] - exp (output_data [d ])* sum ;
131+ gradInput_data [d * stride ] = gradOutput_data [d * stride ] - exp (output_data [d * stride ])* sum ;
105132 }
106133
107134 THTensor_ (free )(gradOutput );
135+ THTensor_ (free )(output );
108136}
109137
110138#endif
0 commit comments