@@ -36,7 +36,8 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
3636 THCTensor *output,
3737 bool sizeAverage,
3838 THCTensor *weights,
39- THCTensor *total_weight)
39+ THCTensor *total_weight,
40+ long ignore_index)
4041{
4142 THNN_ (SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
4243
@@ -75,7 +76,8 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
7576 THCTensor_ (size)(state, input, 0 ),
7677 THCTensor_ (size)(state, input, 1 ),
7778 THCTensor_ (size)(state, input, 2 ) * THCTensor_ (size)(state, input, 3 ),
78- blocks_per_sample
79+ blocks_per_sample,
80+ ignore_index
7981 );
8082 THCudaCheck (cudaGetLastError ());
8183 if (sizeAverage) {
@@ -98,7 +100,8 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
98100 THCTensor *gradInput,
99101 bool sizeAverage,
100102 THCTensor *weights,
101- THCTensor *total_weight)
103+ THCTensor *total_weight,
104+ long ignore_index)
102105{
103106 THNN_ (SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
104107 THArgCheck (THCTensor_ (isContiguous)(state, gradInput), 4 ,
@@ -134,7 +137,8 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
134137 THCTensor_ (size)(state, input, 0 ),
135138 THCTensor_ (size)(state, input, 1 ),
136139 THCTensor_ (size)(state, input, 2 ) *THCTensor_ (size)(state, input, 3 ),
137- blocks_per_sample
140+ blocks_per_sample,
141+ ignore_index
138142 );
139143 THCudaCheck (cudaGetLastError ());
140144
0 commit comments