Skip to content

Commit fb24bf5

Browse files
clownrat6xvjiarui
andauthored
Fix mIoU calculatiton range (open-mmlab#471)
* Fix fence(IoU) = 0 when training on PascalContextDataset59; * Add a test case in test_metrics() of tests/test_metrics.py to test the bug caused by torch.histc; * Update tests/test_metrics.py Co-authored-by: Jerry Jiarui XU <[email protected]> Co-authored-by: Jerry Jiarui XU <[email protected]>
1 parent 789d1a1 commit fb24bf5

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

mmseg/core/evaluation/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ def intersect_and_union(pred_label,
5757

5858
intersect = pred_label[pred_label == label]
5959
area_intersect = torch.histc(
60-
intersect.float(), bins=(num_classes), min=0, max=num_classes)
60+
intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
6161
area_pred_label = torch.histc(
62-
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
62+
pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
6363
area_label = torch.histc(
64-
label.float(), bins=(num_classes), min=0, max=num_classes)
64+
label.float(), bins=(num_classes), min=0, max=num_classes - 1)
6565
area_union = area_pred_label + area_label - area_intersect
6666
return area_intersect, area_union, area_pred_label, area_label
6767

tests/test_metrics.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,30 +64,36 @@ def test_metrics():
6464
ignore_index = 255
6565
results = np.random.randint(0, num_classes, size=pred_size)
6666
label = np.random.randint(0, num_classes, size=pred_size)
67+
68+
# Test the availability of arg: ignore_index.
6769
label[:, 2, 5:10] = ignore_index
70+
71+
# Test the correctness of the implementation of mIoU calculation.
6872
all_acc, acc, iou = eval_metrics(
6973
results, label, num_classes, ignore_index, metrics='mIoU')
7074
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
7175
ignore_index)
7276
assert all_acc == all_acc_l
7377
assert np.allclose(acc, acc_l)
7478
assert np.allclose(iou, iou_l)
75-
79+
# Test the correctness of the implementation of mDice calculation.
7680
all_acc, acc, dice = eval_metrics(
7781
results, label, num_classes, ignore_index, metrics='mDice')
7882
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
7983
ignore_index)
8084
assert all_acc == all_acc_l
8185
assert np.allclose(acc, acc_l)
8286
assert np.allclose(dice, dice_l)
83-
87+
# Test the correctness of the implementation of joint calculation.
8488
all_acc, acc, iou, dice = eval_metrics(
8589
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
8690
assert all_acc == all_acc_l
8791
assert np.allclose(acc, acc_l)
8892
assert np.allclose(iou, iou_l)
8993
assert np.allclose(dice, dice_l)
9094

95+
# Test the correctness of calculation when arg: num_classes is larger
96+
# than the maximum value of input maps.
9197
results = np.random.randint(0, 5, size=pred_size)
9298
label = np.random.randint(0, 4, size=pred_size)
9399
all_acc, acc, iou = eval_metrics(
@@ -121,6 +127,17 @@ def test_metrics():
121127
assert dice[-1] == -1
122128
assert iou[-1] == -1
123129

130+
# Test the bug which is caused by torch.histc.
131+
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
132+
# When the arg:bins is set to be same as arg:max,
133+
# some channels of mIoU may be nan.
134+
results = np.array([np.repeat(31, 59)])
135+
label = np.array([np.arange(59)])
136+
num_classes = 59
137+
all_acc, acc, iou = eval_metrics(
138+
results, label, num_classes, ignore_index=255, metrics='mIoU')
139+
assert not np.any(np.isnan(iou))
140+
124141

125142
def test_mean_iou():
126143
pred_size = (10, 30, 30)
@@ -182,7 +199,7 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
182199
filenames.append(filename)
183200
return filenames
184201

185-
pred_size = (10, 512, 1024)
202+
pred_size = (10, 30, 30)
186203
num_classes = 19
187204
ignore_index = 255
188205
results = np.random.randint(0, num_classes, size=pred_size)

0 commit comments

Comments
 (0)