
Commit 346f70d

[Fix] Make accuracy take into account ignore_index (open-mmlab#1259)
* make accuracy take into account ignore_index
* add UT for accuracy
1 parent 0934a57 commit 346f70d

File tree: 4 files changed, +37 −6 lines

mmseg/models/decode_heads/decode_head.py

Lines changed: 2 additions & 1 deletion

@@ -261,5 +261,6 @@ def losses(self, seg_logit, seg_label):
                     weight=seg_weight,
                     ignore_index=self.ignore_index)
 
-        loss['acc_seg'] = accuracy(seg_logit, seg_label)
+        loss['acc_seg'] = accuracy(
+            seg_logit, seg_label, ignore_index=self.ignore_index)
         return loss
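With this change, 'acc_seg' is computed only over valid pixels, matching the decode loss, which already received ignore_index. A minimal toy sketch of the difference (not part of the commit; it assumes accuracy is importable from mmseg.models.losses and uses 255 as the ignored label, the usual mmseg convention):

import torch

from mmseg.models.losses import accuracy

# (N, num_class, H, W) logits for a 2-class toy case; the per-pixel argmax
# predicts class 0 for the top row and class 1 for the bottom row.
seg_logit = torch.tensor([[[[2.0, 2.0],
                            [0.1, 0.1]],
                           [[0.1, 0.1],
                            [2.0, 2.0]]]])
# (N, H, W) labels; two of the four pixels carry the ignore label 255.
seg_label = torch.tensor([[[0, 255],
                           [255, 1]]])

print(accuracy(seg_logit, seg_label))                    # 50.0: ignored pixels counted as wrong
print(accuracy(seg_logit, seg_label, ignore_index=255))  # 100.0: only the two valid pixels are scored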

mmseg/models/decode_heads/point_head.py

Lines changed: 2 additions & 1 deletion

@@ -264,7 +264,8 @@ def losses(self, point_logits, point_label):
             loss['point' + loss_module.loss_name] = loss_module(
                 point_logits, point_label, ignore_index=self.ignore_index)
 
-        loss['acc_point'] = accuracy(point_logits, point_label)
+        loss['acc_point'] = accuracy(
+            point_logits, point_label, ignore_index=self.ignore_index)
         return loss
 
     def get_points_train(self, seg_logits, uncertainty_func, cfg):

mmseg/models/losses/accuracy.py

Lines changed: 9 additions & 4 deletions

@@ -2,12 +2,13 @@
 import torch.nn as nn
 
 
-def accuracy(pred, target, topk=1, thresh=None):
+def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
     """Calculate accuracy according to the prediction and target.
 
     Args:
         pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
         target (torch.Tensor): The target of each prediction, shape (N, , ...)
+        ignore_index (int | None): The label index to be ignored. Default: None
         topk (int | tuple[int], optional): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
@@ -43,17 +44,19 @@ def accuracy(pred, target, topk=1, thresh=None):
     if thresh is not None:
         # Only prediction values larger than thresh are counted as correct
         correct = correct & (pred_value > thresh).t()
+    correct = correct[:, target != ignore_index]
     res = []
     for k in topk:
         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(correct_k.mul_(100.0 / target.numel()))
+        res.append(
+            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
     return res[0] if return_single else res
 
 
 class Accuracy(nn.Module):
     """Accuracy calculation module."""
 
-    def __init__(self, topk=(1, ), thresh=None):
+    def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
         """Module to calculate the accuracy.
 
         Args:
@@ -65,6 +68,7 @@ def __init__(self, topk=(1, ), thresh=None):
         super().__init__()
         self.topk = topk
         self.thresh = thresh
+        self.ignore_index = ignore_index
 
     def forward(self, pred, target):
         """Forward function to calculate accuracy.
@@ -76,4 +80,5 @@ def forward(self, pred, target):
         Returns:
             tuple[float]: The accuracies under different topk criterions.
         """
-        return accuracy(pred, target, self.topk, self.thresh)
+        return accuracy(pred, target, self.topk, self.thresh,
+                        self.ignore_index)
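The core of the change is the two masked expressions above: predictions at positions where target equals ignore_index are dropped from the correct matrix, and the denominator shrinks to the number of remaining labels. A rough standalone equivalent for the top-1, no-thresh case (a sketch for sanity-checking the masking, not the library code; the helper name is made up):

import torch


def top1_accuracy_ignoring(pred, target, ignore_index=None):
    """pred: (N, num_class, ...) logits; target: (N, ...) integer labels."""
    pred_label = pred.argmax(dim=1)                    # top-1 prediction per position
    if ignore_index is not None:
        keep = target != ignore_index                  # drop ignored positions
        pred_label, target = pred_label[keep], target[keep]
    correct = (pred_label == target).float().sum()
    return correct * 100.0 / target.numel()            # percentage over the kept labels


pred = torch.tensor([[0.2, 0.3, 0.6, 0.5],
                     [0.1, 0.1, 0.2, 0.6],
                     [0.9, 0.0, 0.0, 0.1],
                     [0.4, 0.7, 0.1, 0.1],
                     [0.0, 0.0, 0.99, 0.0]])
target = torch.tensor([2, 0, 0, 1, 2])
print(top1_accuracy_ignoring(pred, target, ignore_index=1))  # tensor(75.)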

tests/test_models/test_losses/test_utils.py

Lines changed: 24 additions & 0 deletions

@@ -52,6 +52,30 @@ def test_accuracy():
     pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                          [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                          [0.0, 0.0, 0.99, 0]])
+    # test for ignore_index
+    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=None)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 100
+
+    # test for ignore_index with a wrong prediction of that index
+    true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=1)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 100
+
+    # test for ignore_index 1 with a wrong prediction of other index
+    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=1)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 75
+
+    # test for ignore_index 4 with a wrong prediction of other index
+    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=4)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 80
+
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)