Skip to content

Commit a5ae723

Browse files
authored
Revert "Update to ignore zero targets"
1 parent e97095d commit a5ae723

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

lib/THCUNN/ClassNLLCriterion.cu

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
2222
// updateOutput_kernel.
2323

2424
int t = (int)*target - TH_INDEX_BASE;
25-
assert(t >= -1 && t < n_classes);
26-
if (t >= 0) {
27-
Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
28-
*output = -cur_weight * input[t];
29-
*total_weight = cur_weight;
30-
if (size_average && *total_weight > 0) {
31-
*output /= *total_weight;
32-
}
25+
assert(t >= 0 && t < n_classes);
26+
Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
27+
*output = -cur_weight * input[t];
28+
*total_weight = cur_weight;
29+
if (size_average && *total_weight > 0) {
30+
*output /= *total_weight;
3331
}
3432
}
3533

@@ -51,12 +49,10 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
5149
acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
5250
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
5351
t = target[i] - TH_INDEX_BASE;
54-
assert(t >= -1 && t < n_classes);
55-
if (t >= 0) {
56-
cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
57-
shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
58-
acc_weight[threadIdx.x] += cur_weight;
59-
}
52+
assert(t >= 0 && t < n_classes);
53+
cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
54+
shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
55+
acc_weight[threadIdx.x] += cur_weight;
6056
}
6157
__syncthreads();
6258

@@ -95,10 +91,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
9591
}
9692
Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
9793
int t = (int)*target - TH_INDEX_BASE;
98-
assert(t >= -1 && t < n_classes);
99-
if (t >= 0) {
100-
gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
101-
}
94+
assert(t >= 0 && t < n_classes);
95+
gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
10296
}
10397

10498
template <typename Dtype>
@@ -120,10 +114,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
120114

121115
for (i = threadIdx.x; i < nframe; i += NTHREADS) {
122116
t = (int)target[i] - TH_INDEX_BASE;
123-
assert(t >= -1 && t < n_classes);
124-
if (t >= 0) {
125-
gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
126-
}
117+
assert(t >= 0 && t < n_classes);
118+
gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
127119
}
128120
}
129121

0 commit comments

Comments
 (0)