@@ -9,12 +9,14 @@ void THNN_(ClassNLLCriterion_updateOutput)(
99 THTensor * output ,
1010 bool sizeAverage ,
1111 THTensor * weights ,
12- THTensor * total_weight )
12+ THTensor * total_weight ,
13+ long ignore_index )
1314{
1415 THNN_CHECK_DIM_SIZE (output , 1 , 0 , 1 );
1516 THNN_CHECK_DIM_SIZE (total_weight , 1 , 0 , 1 );
1617 int n_dims = THTensor_ (nDimension )(input );
1718 int n_classes = THTensor_ (size )(input , n_dims - 1 );
19+ ignore_index -= TH_INDEX_BASE ;
1820
1921 if (THIndexTensor_ (nDimension )(target ) > 1 ) {
2022 THError ("multi-target not supported" );
@@ -42,9 +44,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
4244
4345 if (THTensor_ (nDimension )(input ) == 1 ) {
4446 int cur_target = target_data [0 ] - TH_INDEX_BASE ;
45- THAssert (cur_target >= 0 && cur_target < n_classes );
46- total_weight_data [0 ] = weights ? weights_data [cur_target ] : 1.0f ;
47- output_data [0 ] = - input_data [cur_target ] * total_weight_data [0 ];
47+ if (cur_target != ignore_index ) {
48+ THAssert (cur_target >= 0 && cur_target < n_classes );
49+ total_weight_data [0 ] = weights ? weights_data [cur_target ] : 1.0f ;
50+ output_data [0 ] = - input_data [cur_target ] * total_weight_data [0 ];
51+ }
4852 } else if (THTensor_ (nDimension )(input ) == 2 ) {
4953 int batch_size = THTensor_ (size )(input , 0 );
5054 THAssert (THIndexTensor_ (size )(target , 0 ) == batch_size );
@@ -54,11 +58,13 @@ void THNN_(ClassNLLCriterion_updateOutput)(
5458 int i ;
5559 for (i = 0 ; i < batch_size ; i ++ ) {
5660 int cur_target = target_data [i ] - TH_INDEX_BASE ;
57- THAssert (cur_target >= 0 && cur_target < n_classes );
61+ if (cur_target != ignore_index ) {
62+ THAssert (cur_target >= 0 && cur_target < n_classes );
5863
59- real cur_weight = weights ? weights_data [cur_target ] : 1.0f ;
60- total_weight_data [0 ] += cur_weight ;
61- output_data [0 ] -= input_data [i * n_target + cur_target ] * cur_weight ;
64+ real cur_weight = weights ? weights_data [cur_target ] : 1.0f ;
65+ total_weight_data [0 ] += cur_weight ;
66+ output_data [0 ] -= input_data [i * n_target + cur_target ] * cur_weight ;
67+ }
6268 }
6369 }
6470
@@ -80,10 +86,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
8086 THTensor * gradInput ,
8187 bool sizeAverage ,
8288 THTensor * weights ,
83- THTensor * total_weight )
89+ THTensor * total_weight ,
90+ long ignore_index )
8491{
8592 int n_dims = THTensor_ (nDimension )(input );
8693 int n_classes = THTensor_ (size )(input , n_dims - 1 );
94+ ignore_index -= TH_INDEX_BASE ;
8795
8896 if (!THTensor_ (isContiguous )(gradInput )) {
8997 THError ("gradInput must be contiguous" );
@@ -102,7 +110,7 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
102110 if (THTensor_ (nDimension )(input ) > 2 ) {
103111 THError ("input tensor should be 1D or 2D" );
104112 }
105-
113+
106114 if (weights && THTensor_ (nElement )(weights ) != n_classes ) {
107115 THError ("weight tensor should be defined either for all or no classes" );
108116 }
@@ -116,10 +124,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
116124
117125 if (THTensor_ (nDimension )(input ) == 1 ) {
118126 int cur_target = target_data [0 ] - TH_INDEX_BASE ;
119- THAssert (cur_target >= 0 && cur_target < n_classes );
127+ if (cur_target != ignore_index ) {
128+ THAssert (cur_target >= 0 && cur_target < n_classes );
120129
121- gradInput_data [cur_target ] =
122- (!sizeAverage && weights ) ? - weights_data [cur_target ] : -1 ;
130+ gradInput_data [cur_target ] =
131+ (!sizeAverage && weights ) ? - weights_data [cur_target ] : -1 ;
132+ }
123133
124134 } else if (THTensor_ (nDimension )(input ) == 2 ) {
125135 int batch_size = THTensor_ (size )(input , 0 );
@@ -131,13 +141,15 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
131141 for (i = 0 ; i < batch_size ; i ++ ){
132142 int cur_target = target_data [i ] - TH_INDEX_BASE ;
133143
134- THAssert (cur_target >= 0 && cur_target < n_classes );
144+ if (cur_target != ignore_index ) {
145+ THAssert (cur_target >= 0 && cur_target < n_classes );
135146
136- gradInput_data [i * n_target + cur_target ] =
137- - (weights ? weights_data [cur_target ] : 1.0f );
147+ gradInput_data [i * n_target + cur_target ] =
148+ - (weights ? weights_data [cur_target ] : 1.0f );
138149
139- if (sizeAverage && * total_weight_data ) {
140- gradInput_data [i * n_target + cur_target ] /= * total_weight_data ;
150+ if (sizeAverage && * total_weight_data ) {
151+ gradInput_data [i * n_target + cur_target ] /= * total_weight_data ;
152+ }
141153 }
142154 }
143155 }
0 commit comments