Skip to content

Commit d6bc264

Browse files
fyu authored and soumith committed
Add ignore_index to NLLLoss2d
1 parent a5a8ab1 commit d6bc264

File tree

9 files changed

+54
-25
lines changed

9 files changed

+54
-25
lines changed

test/common_nn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,13 @@
283283
target=torch.rand(2, 5, 5).mul(3).floor().long(),
284284
desc='weights'
285285
),
286+
dict(
287+
module_name='NLLLoss2d',
288+
constructor_args=(None, True, 3),
289+
input_size=(2, 3, 5, 5),
290+
target=torch.rand(2, 5, 5).mul(4).floor().long(),
291+
desc='ignore_index'
292+
),
286293
dict(
287294
module_name='HingeEmbeddingLoss',
288295
input=torch.rand(10),

torch/legacy/nn/SpatialClassNLLCriterion.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,28 @@
44

55
class SpatialClassNLLCriterion(Criterion):
66

7-
def __init__(self, weights=None, sizeAverage=True):
7+
def __init__(self, weights=None, sizeAverage=True, ignore_index=-100):
88
assert weights is None or weights.dim() == 1
99
super(SpatialClassNLLCriterion, self).__init__()
1010
self.sizeAverage = sizeAverage
1111
self.weights = weights
12+
self.ignore_index = ignore_index
1213

1314
self.output_tensor = torch.zeros(1)
1415
self.total_weight_tensor = torch.ones(1)
1516

1617
def updateOutput(self, input, target):
18+
if not hasattr(self, 'ignore_index'):
19+
self.ignore_index = -100
1720
self._backend.SpatialClassNLLCriterion_updateOutput(
1821
self._backend.library_state,
1922
input,
2023
target,
2124
self.output_tensor,
2225
self.sizeAverage,
2326
self.weights,
24-
self.total_weight_tensor
27+
self.total_weight_tensor,
28+
self.ignore_index
2529
)
2630
self.output = self.output_tensor[0]
2731
return self.output
@@ -35,6 +39,7 @@ def updateGradInput(self, input, target):
3539
self.gradInput,
3640
self.sizeAverage,
3741
self.weights,
38-
self.total_weight_tensor
42+
self.total_weight_tensor,
43+
self.ignore_index
3944
)
4045
return self.gradInput

torch/lib/THCUNN/SpatialClassNLLCriterion.cu

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
1818
int batch_size,
1919
int n_classes,
2020
int map_nelem,
21-
int blocks_per_sample)
21+
int blocks_per_sample,
22+
long ignore_index)
2223
{
2324
__shared__ AccumT partial_sums[CUDA_NUM_THREADS];
2425

@@ -35,10 +36,12 @@ __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
3536
i < map_nelem;
3637
i += step) {
3738
t = target[toffset + i] - TH_INDEX_BASE;
38-
assert(t >= 0 && t < n_classes);
39-
cur_weight = weights ? weights[t] : ScalarConvert<int, T>::to(1);
40-
input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
41-
acc_weight += cur_weight;
39+
if (t != ignore_index) {
40+
assert(t >= 0 && t < n_classes);
41+
cur_weight = weights ? weights[t] : ScalarConvert<int, T>::to(1);
42+
input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
43+
acc_weight += cur_weight;
44+
}
4245
}
4346

4447
__syncthreads();
@@ -71,7 +74,8 @@ __global__ void cunn_SpatialClassNLLCriterion_updateGradInput_kernel(
7174
int batch_size,
7275
int n_classes,
7376
int map_nelem,
74-
int blocks_per_sample)
77+
int blocks_per_sample,
78+
long ignore_index)
7579
{
7680
if (*total_weight <= 0)
7781
return;
@@ -87,8 +91,10 @@ __global__ void cunn_SpatialClassNLLCriterion_updateGradInput_kernel(
8791
i < map_nelem;
8892
i += step) {
8993
t = (int)target[toffset + i] - TH_INDEX_BASE;
90-
assert(t >= 0 && t < n_classes);
91-
gradInput[ioffset + i + map_nelem * t] = -(weights ? weights[t] : ScalarConvert<int, T>::to(1)) * norm;
94+
if (t != ignore_index) {
95+
assert(t >= 0 && t < n_classes);
96+
gradInput[ioffset + i + map_nelem * t] = -(weights ? weights[t] : ScalarConvert<int, T>::to(1)) * norm;
97+
}
9298
}
9399
}
94100

torch/lib/THCUNN/generic/SpatialClassNLLCriterion.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
3636
THCTensor *output,
3737
bool sizeAverage,
3838
THCTensor *weights,
39-
THCTensor *total_weight)
39+
THCTensor *total_weight,
40+
long ignore_index)
4041
{
4142
THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
4243

@@ -75,7 +76,8 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
7576
THCTensor_(size)(state, input, 0),
7677
THCTensor_(size)(state, input, 1),
7778
THCTensor_(size)(state, input, 2) * THCTensor_(size)(state, input, 3),
78-
blocks_per_sample
79+
blocks_per_sample,
80+
ignore_index
7981
);
8082
THCudaCheck(cudaGetLastError());
8183
if (sizeAverage) {
@@ -98,7 +100,8 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
98100
THCTensor *gradInput,
99101
bool sizeAverage,
100102
THCTensor *weights,
101-
THCTensor *total_weight)
103+
THCTensor *total_weight,
104+
long ignore_index)
102105
{
103106
THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
104107
THArgCheck(THCTensor_(isContiguous)(state, gradInput), 4,
@@ -134,7 +137,8 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
134137
THCTensor_(size)(state, input, 0),
135138
THCTensor_(size)(state, input, 1),
136139
THCTensor_(size)(state, input, 2) *THCTensor_(size)(state, input, 3),
137-
blocks_per_sample
140+
blocks_per_sample,
141+
ignore_index
138142
);
139143
THCudaCheck(cudaGetLastError());
140144

torch/lib/THCUNN/generic/THCUNN.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,8 @@ TH_API void THNN_(SpatialClassNLLCriterion_updateOutput)(
553553
THCTensor *output,
554554
bool sizeAverage,
555555
THCTensor *weights, // [OPTIONAL]
556-
THCTensor *total_weight);
556+
THCTensor *total_weight,
557+
long ignore_index);
557558

558559
TH_API void THNN_(SpatialClassNLLCriterion_updateGradInput)(
559560
THCState *state,
@@ -562,7 +563,8 @@ TH_API void THNN_(SpatialClassNLLCriterion_updateGradInput)(
562563
THCTensor *gradInput,
563564
bool sizeAverage,
564565
THCTensor *weights, // [OPTIONAL]
565-
THCTensor *total_weight);
566+
THCTensor *total_weight,
567+
long ignore_index);
566568

567569
TH_API void THNN_(SpatialConvolutionLocal_updateOutput)(
568570
THCState *state,

torch/lib/THNN/generic/SpatialClassNLLCriterion.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
3434
THTensor *output,
3535
bool sizeAverage,
3636
THTensor *weights,
37-
THTensor *total_weight)
37+
THTensor *total_weight,
38+
long ignore_index)
3839
{
3940
INITIAL_CHECK;
4041

@@ -58,6 +59,7 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
5859
for (int b = 0; b < batch_size; b++) {
5960
for (int elem = 0; elem < map_size; elem++) {
6061
int cur_target = target_data[b * map_size + elem] - TH_INDEX_BASE;
62+
if (cur_target == ignore_index) continue;
6163
THAssert(cur_target >= 0 && cur_target < n_classes);
6264

6365
real cur_weight = weights ? weights_data[cur_target] : 1.0f;
@@ -84,7 +86,8 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
8486
THTensor *gradInput,
8587
bool sizeAverage,
8688
THTensor *weights,
87-
THTensor *total_weight)
89+
THTensor *total_weight,
90+
long ignore_index)
8891
{
8992
INITIAL_CHECK;
9093
THArgCheck(THTensor_(isContiguous)(gradInput), 4,
@@ -114,6 +117,7 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
114117
int elem;
115118
for (elem = 0; elem < map_size; elem++) {
116119
int cur_target = target_data[b * map_size + elem] - TH_INDEX_BASE;
120+
if (cur_target == ignore_index) continue;
117121
THAssert(cur_target >= 0 && cur_target < n_classes);
118122

119123
gradInput_data[b * sample_size + cur_target * map_size + elem] =

torch/lib/THNN/generic/THNN.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,18 @@ TH_API void THNN_(SpatialClassNLLCriterion_updateOutput)(
6666
THTensor *output, // [OUT] a one-element tensor with loss
6767
bool sizeAverage, // if true, the loss will be normalized by batch size and class weights
6868
THTensor *weights, // [OPTIONAL] class weights
69-
THTensor *total_weight); // [BUFFER]
69+
THTensor *total_weight, // [BUFFER]
70+
long ignore_index); // target index to ignore (loss = 0, gradInput = 0)
7071
TH_API void THNN_(SpatialClassNLLCriterion_updateGradInput)(
7172
THNNState *state, // library's state
7273
THTensor *input, // input tensor (4D)
7374
THIndexTensor *target, // tensor containing indexes of target classes (3D)
7475
THTensor *gradInput, // [OUT] gradient w.r.t. input
7576
bool sizeAverage, // if true, the loss will be normalized by batch size and class weights
7677
THTensor *weights, // [OPTIONAL] class weights
77-
THTensor *total_weight); // [BUFFER]
78+
THTensor *total_weight, // [BUFFER]
79+
long ignore_index); // target index to ignore (loss = 0, gradInput = 0)
80+
7881

7982
TH_API void THNN_(ELU_updateOutput)(
8083
THNNState *state, // library's state

torch/nn/functional.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100):
595595
if dim == 2:
596596
f = _functions.thnn.NLLLoss(size_average, ignore_index, weight=weight)
597597
elif dim == 4:
598-
if ignore_index != -100:
599-
raise ValueError('ignore_index is not supported for 4-D inputs')
600-
f = _functions.thnn.NLLLoss2d(size_average, weight=weight)
598+
f = _functions.thnn.NLLLoss2d(size_average, ignore_index, weight=weight)
601599
else:
602600
raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))
603601
return f(input, target)

torch/nn/modules/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def forward(self, input, target):
124124
self.ignore_index)
125125

126126

127-
class NLLLoss2d(_WeightedLoss):
127+
class NLLLoss2d(NLLLoss):
128128
r"""This is negative log likelihood loss, but for image inputs. It computes
129129
NLL loss per-pixel.
130130

0 commit comments

Comments
 (0)