@@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
1515 THCIndex_t *target,
1616 Dtype *weights,
1717 int size_average,
18- int n_classes) {
18+ int n_classes,
19+ long ignore_index) {
1920 assert (threadIdx .x == 0 && threadIdx .y == 0 && threadIdx .z == 0 );
2021
2122 // TODO: T4951791 Reuse code between updateOutput_kernel1 and
2223 // updateOutput_kernel.
2324
2425 int t = (int )*target - TH_INDEX_BASE;
25- assert (t >= 0 && t < n_classes);
26- Dtype cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
27- *output = -cur_weight * input[t];
28- *total_weight = cur_weight;
29- if (size_average && *total_weight > 0 ) {
30- *output /= *total_weight;
26+ if (t != ignore_index) {
27+ assert (t >= 0 && t < n_classes);
28+ Dtype cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
29+ *output = -cur_weight * input[t];
30+ *total_weight = cur_weight;
31+ if (size_average && *total_weight > 0 ) {
32+ *output /= *total_weight;
33+ }
3134 }
3235}
3336
@@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
4043 int size_average,
4144 int nframe,
4245 int ndim,
43- int n_classes) {
46+ int n_classes,
47+ long ignore_index) {
4448 __shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS];
4549 int i, t;
4650 Dtype cur_weight;
@@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
4953 acc_weight[threadIdx .x ] = ScalarConvert<int , Acctype>::to (0 );
5054 for (i = threadIdx .x ; i < nframe; i += NTHREADS) {
5155 t = target[i] - TH_INDEX_BASE;
52- assert (t >= 0 && t < n_classes);
53- cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
54- shInputs[threadIdx .x ] -= input[i * ndim + t] * cur_weight;
55- acc_weight[threadIdx .x ] += cur_weight;
56+ if (t != ignore_index) {
57+ assert (t >= 0 && t < n_classes);
58+ cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
59+ shInputs[threadIdx .x ] -= input[i * ndim + t] * cur_weight;
60+ acc_weight[threadIdx .x ] += cur_weight;
61+ }
5662 }
5763 __syncthreads ();
5864
@@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
8490 THCIndex_t* target,
8591 Dtype* total_weight,
8692 int size_average,
87- int n_classes)
93+ int n_classes,
94+ long ignore_index)
8895{
8996 if (*total_weight <= 0 ) {
9097 return ;
9198 }
9299 Dtype norm = size_average ? (ScalarConvert<int , Dtype>::to (1 ) / *total_weight) : ScalarConvert<int , Dtype>::to (1 );
93100 int t = (int )*target - TH_INDEX_BASE;
94- assert (t >= 0 && t < n_classes);
95- gradInput[t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
101+ if (t != ignore_index) {
102+ assert (t >= 0 && t < n_classes);
103+ gradInput[t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
104+ }
96105}
97106
98107template <typename Dtype>
@@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
104113 int size_average,
105114 int nframe,
106115 int ndim,
107- int n_classes)
116+ int n_classes,
117+ long ignore_index)
108118{
109119 if (*total_weight <= 0 ) {
110120 return ;
@@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
114124
115125 for (i = threadIdx .x ; i < nframe; i += NTHREADS) {
116126 t = (int )target[i] - TH_INDEX_BASE;
117- assert (t >= 0 && t < n_classes);
118- gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
127+ if (t != ignore_index) {
128+ assert (t >= 0 && t < n_classes);
129+ gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
130+ }
119131 }
120132}
121133
0 commit comments