
Commit 66b3790

[Fix] Fix the bug that when all pixels in an image are ignored, the accuracy calculation raises ZeroDivisionError (open-mmlab#1336)

* [Fix] Fix the bug that when all pixels in an image are ignored, the accuracy calculation raises ZeroDivisionError
* use eps
* all close
* add ignore test
* add eps
1 parent 8f33d68 commit 66b3790
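The substance of the fix: when every pixel in an image carries the ignore label, target[target != ignore_index].numel() is 0, so the old 100.0 / numel() division raised ZeroDivisionError. The patch adds a float32 epsilon to both the correct count and the pixel count, which keeps the division defined and makes a fully ignored image report roughly 100% accuracy. A minimal standalone sketch of that idea (safe_ratio is a hypothetical name for illustration, not part of mmseg):

import torch

def safe_ratio(correct_count, total_count):
    # Mirror of the patched division in accuracy(): an epsilon in both the
    # numerator and the denominator keeps the ratio defined when
    # total_count == 0 (i.e. every pixel was ignored).
    eps = torch.finfo(torch.float32).eps
    return (correct_count + eps) * (100.0 / (total_count + eps))

print(safe_ratio(torch.tensor([0.0]), 0))  # all pixels ignored -> tensor([100.])
print(safe_ratio(torch.tensor([3.0]), 4))  # ordinary case -> about tensor([75.])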

File tree

2 files changed: +21 -11 lines changed


mmseg/models/losses/accuracy.py

Lines changed: 7 additions & 3 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn


@@ -46,10 +47,13 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
         correct = correct & (pred_value > thresh).t()
     correct = correct[:, target != ignore_index]
     res = []
+    eps = torch.finfo(torch.float32).eps
     for k in topk:
-        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(
-            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
+        # Avoid causing ZeroDivisionError when all pixels
+        # of an image are ignored
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
+        total_num = target[target != ignore_index].numel() + eps
+        res.append(correct_k.mul_(100.0 / total_num))
     return res[0] if return_single else res
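With the patch applied, calling the accuracy helper on an image whose pixels are all ignored no longer raises. A short usage sketch of the new behaviour (the Accuracy import path is assumed from the test module below, and the values mirror the added "ignore all the pixels" case):

import torch
from mmseg.models.losses import Accuracy  # import path assumed, as used by the tests

pred = torch.rand(5, 4)                             # scores for 4 classes over 5 pixels
true_label = torch.full((5,), 2, dtype=torch.long)  # every pixel carries label 2

acc_fn = Accuracy(topk=1, ignore_index=2)
acc = acc_fn(pred, true_label)
# Before this commit the call above raised ZeroDivisionError;
# with the eps guard it returns a value numerically close to 100.
assert torch.allclose(acc, torch.tensor(100.0))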

tests/test_models/test_losses/test_utils.py

Lines changed: 14 additions & 8 deletions
@@ -56,50 +56,56 @@ def test_accuracy():
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=None)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for ignore_index with a wrong prediction of that index
     true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for ignore_index 1 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 75
+    assert torch.allclose(acc, torch.tensor(75.0))

     # test for ignore_index 4 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=4)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 80
+    assert torch.allclose(acc, torch.tensor(80.0))
+
+    # test for ignoring all the pixels
+    true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=2)
+    acc = accuracy(pred, true_label)
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for top1 with score thresh=0.8
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, thresh=0.8)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 40
+    assert torch.allclose(acc, torch.tensor(40.0))

     # test for top2
     accuracy = Accuracy(topk=2)
     label = torch.Tensor([3, 2, 0, 0, 2]).long()
     acc = accuracy(pred, label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for both top1 and top2
     accuracy = Accuracy(topk=(1, 2))
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     acc = accuracy(pred, true_label)
     for a in acc:
-        assert a.item() == 100
+        assert torch.allclose(a, torch.tensor(100.0))

     # topk is larger than pred class number
     with pytest.raises(AssertionError):
