Skip to content

Commit cb06ff1

Browse files
authored
Fix potential bugs in accuracy.py (open-mmlab#1496)
1 parent 618d3c3 commit cb06ff1

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

mmseg/models/losses/accuracy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,18 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
4545
if thresh is not None:
4646
# Only prediction values larger than thresh are counted as correct
4747
correct = correct & (pred_value > thresh).t()
48-
correct = correct[:, target != ignore_index]
48+
if ignore_index is not None:
49+
correct = correct[:, target != ignore_index]
4950
res = []
5051
eps = torch.finfo(torch.float32).eps
5152
for k in topk:
5253
# Avoid causing ZeroDivisionError when all pixels
5354
# of an image are ignored
5455
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
55-
total_num = target[target != ignore_index].numel() + eps
56+
if ignore_index is not None:
57+
total_num = target[target != ignore_index].numel() + eps
58+
else:
59+
total_num = target.numel() + eps
5660
res.append(correct_k.mul_(100.0 / total_num))
5761
return res[0] if return_single else res
5862

0 commit comments

Comments
 (0)