Skip to content

Commit 768c3ee

Browse files
sshuairxvjiarui
andauthored
fix acc and iou compute nan problem (open-mmlab#116)
* fix acc and iou compute nan problem * fix acc and iou compute nan problem * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * Update mmseg/core/evaluation/mean_iou.py * Update mean_iou.py * Update mean_iou.py Co-authored-by: Jerry Jiarui XU <[email protected]>
1 parent e3f6d57 commit 768c3ee

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

mmseg/core/evaluation/mean_iou.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
3434
return area_intersect, area_union, area_pred_label, area_label
3535

3636

37-
def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
37+
def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
3838
"""Calculate Intersection and Union (IoU)
3939
4040
Args:
4141
results (list[ndarray]): List of prediction segmentation maps
4242
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
4343
num_classes (int): Number of categories
4444
ignore_index (int): Index that will be ignored in evaluation.
45+
nan_to_num (int, optional): If specified, NaN values will be replaced
46+
by the numbers defined by the user. Default: None.
4547
4648
Returns:
4749
float: Overall accuracy on all images.
@@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
6668
all_acc = total_area_intersect.sum() / total_area_label.sum()
6769
acc = total_area_intersect / total_area_label
6870
iou = total_area_intersect / total_area_union
69-
71+
if nan_to_num is not None:
72+
return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
73+
np.nan_to_num(iou, nan=nan_to_num)
7074
return all_acc, acc, iou

tests/test_mean_iou.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,10 @@ def test_mean_iou():
5454
assert all_acc == all_acc_l
5555
assert np.allclose(acc, acc_l)
5656
assert np.allclose(iou, iou_l)
57+
58+
results = np.random.randint(0, 5, size=pred_size)
59+
label = np.random.randint(0, 4, size=pred_size)
60+
all_acc, acc, iou = mean_iou(
61+
results, label, num_classes, ignore_index=255, nan_to_num=-1)
62+
assert acc[-1] == -1
63+
assert iou[-1] == -1

0 commit comments

Comments
 (0)