@@ -22,14 +22,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
2222 // updateOutput_kernel.
2323
2424 int t = (int )*target - TH_INDEX_BASE;
25- assert (t >= -1 && t < n_classes);
26- if (t >= 0 ) {
27- Dtype cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
28- *output = -cur_weight * input[t];
29- *total_weight = cur_weight;
30- if (size_average && *total_weight > 0 ) {
31- *output /= *total_weight;
32- }
25+ assert (t >= 0 && t < n_classes);
26+ Dtype cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
27+ *output = -cur_weight * input[t];
28+ *total_weight = cur_weight;
29+ if (size_average && *total_weight > 0 ) {
30+ *output /= *total_weight;
3331 }
3432}
3533
@@ -51,12 +49,10 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
5149 acc_weight[threadIdx .x ] = ScalarConvert<int , Acctype>::to (0 );
5250 for (i = threadIdx .x ; i < nframe; i += NTHREADS) {
5351 t = target[i] - TH_INDEX_BASE;
54- assert (t >= -1 && t < n_classes);
55- if (t >= 0 ) {
56- cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
57- shInputs[threadIdx .x ] -= input[i * ndim + t] * cur_weight;
58- acc_weight[threadIdx .x ] += cur_weight;
59- }
52+ assert (t >= 0 && t < n_classes);
53+ cur_weight = weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 );
54+ shInputs[threadIdx .x ] -= input[i * ndim + t] * cur_weight;
55+ acc_weight[threadIdx .x ] += cur_weight;
6056 }
6157 __syncthreads ();
6258
@@ -95,10 +91,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
9591 }
9692 Dtype norm = size_average ? (ScalarConvert<int , Dtype>::to (1 ) / *total_weight) : ScalarConvert<int , Dtype>::to (1 );
9793 int t = (int )*target - TH_INDEX_BASE;
98- assert (t >= -1 && t < n_classes);
99- if (t >= 0 ) {
100- gradInput[t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
101- }
94+ assert (t >= 0 && t < n_classes);
95+ gradInput[t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
10296}
10397
10498template <typename Dtype>
@@ -120,10 +114,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
120114
121115 for (i = threadIdx .x ; i < nframe; i += NTHREADS) {
122116 t = (int )target[i] - TH_INDEX_BASE;
123- assert (t >= -1 && t < n_classes);
124- if (t >= 0 ) {
125- gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
126- }
117+ assert (t >= 0 && t < n_classes);
118+ gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int , Dtype>::to (1 )) * norm;
127119 }
128120}
129121
0 commit comments