Commit f2b2286

Merge pull request #468 from nicholas-leonard/ClassNLLCriterion
ClassNLLCriterion ignoreIndex
2 parents 501b31c + 53f7b25 commit f2b2286

4 files changed: +71, -26 lines changed

lib/THCUNN/ClassNLLCriterion.cu

Lines changed: 30 additions & 18 deletions
@@ -15,19 +15,22 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel1(Dtype *output,
                                                             THCIndex_t *target,
                                                             Dtype *weights,
                                                             int size_average,
-                                                            int n_classes) {
+                                                            int n_classes,
+                                                            long ignore_index) {
   assert(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);

   // TODO: T4951791 Reuse code between updateOutput_kernel1 and
   // updateOutput_kernel.

   int t = (int)*target - TH_INDEX_BASE;
-  assert(t >= 0 && t < n_classes);
-  Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
-  *output = -cur_weight * input[t];
-  *total_weight = cur_weight;
-  if (size_average && *total_weight > 0) {
-    *output /= *total_weight;
+  if (t != ignore_index) {
+    assert(t >= 0 && t < n_classes);
+    Dtype cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+    *output = -cur_weight * input[t];
+    *total_weight = cur_weight;
+    if (size_average && *total_weight > 0) {
+      *output /= *total_weight;
+    }
   }
 }

@@ -40,7 +43,8 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
                                                            int size_average,
                                                            int nframe,
                                                            int ndim,
-                                                           int n_classes) {
+                                                           int n_classes,
+                                                           long ignore_index) {
   __shared__ Acctype shInputs[NTHREADS], acc_weight[NTHREADS];
   int i, t;
   Dtype cur_weight;
@@ -49,10 +53,12 @@ __global__ void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *output,
   acc_weight[threadIdx.x] = ScalarConvert<int, Acctype>::to(0);
   for (i = threadIdx.x; i < nframe; i += NTHREADS) {
     t = target[i] - TH_INDEX_BASE;
-    assert(t >= 0 && t < n_classes);
-    cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
-    shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
-    acc_weight[threadIdx.x] += cur_weight;
+    if (t != ignore_index) {
+      assert(t >= 0 && t < n_classes);
+      cur_weight = weights ? weights[t] : ScalarConvert<int, Dtype>::to(1);
+      shInputs[threadIdx.x] -= input[i * ndim + t] * cur_weight;
+      acc_weight[threadIdx.x] += cur_weight;
+    }
   }
   __syncthreads();

@@ -84,15 +90,18 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel1(
   THCIndex_t* target,
   Dtype* total_weight,
   int size_average,
-  int n_classes)
+  int n_classes,
+  long ignore_index)
 {
   if (*total_weight <= 0) {
     return;
   }
   Dtype norm = size_average ? (ScalarConvert<int, Dtype>::to(1) / *total_weight) : ScalarConvert<int, Dtype>::to(1);
   int t = (int)*target - TH_INDEX_BASE;
-  assert(t >= 0 && t < n_classes);
-  gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+  if (t != ignore_index) {
+    assert(t >= 0 && t < n_classes);
+    gradInput[t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+  }
 }

 template <typename Dtype>
@@ -104,7 +113,8 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(
   int size_average,
   int nframe,
   int ndim,
-  int n_classes)
+  int n_classes,
+  long ignore_index)
 {
   if (*total_weight <= 0) {
     return;
@@ -114,8 +124,10 @@ __global__ void cunn_ClassNLLCriterion_updateGradInput_kernel(

   for (i = threadIdx.x; i < nframe; i += NTHREADS) {
     t = (int)target[i] - TH_INDEX_BASE;
-    assert(t >= 0 && t < n_classes);
-    gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+    if (t != ignore_index) {
+      assert(t >= 0 && t < n_classes);
+      gradInput[i * ndim + t] = -(weights ? weights[t] : ScalarConvert<int, Dtype>::to(1)) * norm;
+    }
   }
 }
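
To make the new semantics concrete, here is a minimal CPU sketch of what the batched forward kernel now computes. It is not part of this patch; the function name, float types, and the example values in main are illustrative only. A target equal to ignore_index contributes neither loss nor weight, and the size_average division uses only the weight that was actually accumulated.

// Hypothetical CPU reference for the batched forward pass above.
// Assumes targets are already 0-based (TH_INDEX_BASE already subtracted).
#include <cassert>

float nll_forward_ref(const float *input, const long *target,
                      const float *weights, int nframe, int ndim,
                      int n_classes, bool size_average, long ignore_index,
                      float *total_weight_out) {
  float output = 0.0f;
  float total_weight = 0.0f;
  for (int i = 0; i < nframe; ++i) {
    long t = target[i];
    if (t == ignore_index) continue;        // ignored targets are skipped entirely
    assert(t >= 0 && t < n_classes);
    float w = weights ? weights[t] : 1.0f;  // optional per-class weights
    output -= input[i * ndim + t] * w;      // negative log-likelihood term
    total_weight += w;
  }
  if (size_average && total_weight > 0.0f)
    output /= total_weight;                 // average over non-ignored weight only
  *total_weight_out = total_weight;
  return output;
}

int main() {
  // Two frames, three classes; frame 1 carries the ignore index.
  const float input[] = {-1.0f, -2.0f, -3.0f,
                         -0.5f, -1.5f, -2.5f};
  const long target[] = {2, -1};            // -1 is ignored
  float total_weight = 0.0f;
  float loss = nll_forward_ref(input, target, /*weights=*/nullptr,
                               /*nframe=*/2, /*ndim=*/3, /*n_classes=*/3,
                               /*size_average=*/true, /*ignore_index=*/-1,
                               &total_weight);
  // Only frame 0 contributes: loss == 3.0, total_weight == 1.0.
  assert(total_weight == 1.0f && loss == 3.0f);
  return 0;
}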

lib/THCUNN/generic/ClassNLLCriterion.cu

Lines changed: 14 additions & 6 deletions
@@ -9,9 +9,11 @@ void THNN_(ClassNLLCriterion_updateOutput)(
            THCTensor *output,
            bool sizeAverage,
            THCTensor *weights,
-           THCTensor *total_weight) {
+           THCTensor *total_weight,
+           long ignore_index) {
   THCUNN_check_dim_size(state, output, 1, 0, 1);
   THCUNN_check_dim_size(state, total_weight, 1, 0, 1);
+  ignore_index -= TH_INDEX_BASE;

   if (THCIndexTensor_(nDimension)(state, target) > 1) {
     THError("multi-target not supported");
@@ -63,7 +65,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
         target_data,
         weights_data,
         sizeAverage,
-        n_classes
+        n_classes,
+        ignore_index
     );

   } else if (THCTensor_(nDimension)(state, input) == 2) {
@@ -77,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)(
         sizeAverage,
         THCTensor_(size)(state, input, 0),
         THCTensor_(size)(state, input, 1),
-        n_classes
+        n_classes,
+        ignore_index
     );
   }
   THCudaCheck(cudaGetLastError());
@@ -96,10 +100,12 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
            THCTensor *gradInput,
            bool sizeAverage,
            THCTensor *weights,
-           THCTensor *total_weight) {
+           THCTensor *total_weight,
+           long ignore_index) {
   if (THCIndexTensor_(nDimension)(state, target) > 1) {
     THError("multi-target not supported");
   }
+  ignore_index -= TH_INDEX_BASE;

   int n_dims = THCTensor_(nDimension)(state, input);
   int n_classes = THCTensor_(size)(state, input, n_dims - 1);
@@ -145,7 +151,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
         target_data,
         total_weight_data,
         sizeAverage,
-        n_classes
+        n_classes,
+        ignore_index
     );
   } else {
     cunn_ClassNLLCriterion_updateGradInput_kernel<real>
@@ -157,7 +164,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)(
         sizeAverage,
         THCTensor_(size)(state, input, 0),
         THCTensor_(size)(state, input, 1),
-        n_classes
+        n_classes,
+        ignore_index
     );
   }
   THCudaCheck(cudaGetLastError());
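
One detail worth flagging in the wrapper above: ignore_index is shifted by TH_INDEX_BASE once on entry, while each target is shifted per element inside the kernels, so the comparison always happens in 0-based terms on both sides. A tiny standalone illustration (hypothetical, not part of the patch, assuming the Lua frontend where TH_INDEX_BASE is 1):

// Hypothetical illustration of the index-base handling in the wrapper above.
#include <cassert>

int main() {
  const long TH_INDEX_BASE = 1;    // 1 for the Lua frontend
  long ignore_index = -1;          // ignoreIndex as passed by the caller
  long target = -1;                // a target the caller wants ignored

  ignore_index -= TH_INDEX_BASE;   // done once in updateOutput/updateGradInput
  long t = target - TH_INDEX_BASE; // done per element in the kernels

  assert(t == ignore_index);       // both are -2, so the element is skipped
  return 0;
}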

lib/THCUNN/generic/THCUNN.h

Lines changed: 4 additions & 2 deletions
@@ -80,7 +80,8 @@ TH_API void THNN_(ClassNLLCriterion_updateOutput)(
                   THCTensor *output,
                   bool sizeAverage,
                   THCTensor *weights, // [OPTIONAL]
-                  THCTensor *total_weight);
+                  THCTensor *total_weight,
+                  long ignore_index);

 TH_API void THNN_(ClassNLLCriterion_updateGradInput)(
                   THCState *state,
@@ -89,7 +90,8 @@ TH_API void THNN_(ClassNLLCriterion_updateGradInput)(
                   THCTensor *gradInput,
                   bool sizeAverage,
                   THCTensor *weights, // [OPTIONAL]
-                  THCTensor *total_weight);
+                  THCTensor *total_weight,
+                  long ignore_index);

 TH_API void THNN_(DistKLDivCriterion_updateOutput)(
                   THCState *state,

test.lua

Lines changed: 23 additions & 0 deletions
@@ -4626,6 +4626,29 @@ function cunntest.ClassNLLCriterionMultipleTargetWeights()
    end
 end

+function cunntest.ClassNLLCriterion_ignoreIndex()
+   local numLabels = 10
+   local batchsize = 4
+   local ignoreIndex = -1
+   local cri = nn.ClassNLLCriterion(nil, nil, ignoreIndex):cuda()
+   local input = torch.randn(numLabels):cuda()
+   local target = ignoreIndex
+   mytester:assert(cri:forward(input, target) == 0)
+   mytester:assert(cri:backward(input, target):abs():sum() == 0)
+   local input = torch.randn(batchsize, numLabels):cuda()
+   local target = torch.LongTensor(batchsize):random(1,numLabels)
+   target[1] = ignoreIndex
+   target = target:cudaLong()
+   local output = cri:forward(input, target)
+   local gradInput = cri:backward(input, target):clone()
+   mytester:assert(gradInput[1]:abs():sum() == 0)
+   local input, target = input:sub(2,batchsize), target:sub(2,batchsize)
+   local output2 = cri:forward(input, target)
+   mytester:assert(math.abs(output2 - output) < 0.0000001)
+   local gradInput2 = cri:backward(input, target)
+   mytester:assertTensorEq(gradInput2, gradInput:sub(2,batchsize), 0.0000001)
+end
+
 function cunntest.TemporalMaxPooling()
    local settings = {{2, 2}, {3, 3}, {4, 2}, {2, 4}, {3, 5}}

