@@ -64,30 +64,36 @@ def test_metrics():
64
64
ignore_index = 255
65
65
results = np .random .randint (0 , num_classes , size = pred_size )
66
66
label = np .random .randint (0 , num_classes , size = pred_size )
67
+
68
+ # Test the availability of arg: ignore_index.
67
69
label [:, 2 , 5 :10 ] = ignore_index
70
+
71
+ # Test the correctness of the implementation of mIoU calculation.
68
72
all_acc , acc , iou = eval_metrics (
69
73
results , label , num_classes , ignore_index , metrics = 'mIoU' )
70
74
all_acc_l , acc_l , iou_l = legacy_mean_iou (results , label , num_classes ,
71
75
ignore_index )
72
76
assert all_acc == all_acc_l
73
77
assert np .allclose (acc , acc_l )
74
78
assert np .allclose (iou , iou_l )
75
-
79
+ # Test the correctness of the implementation of mDice calculation.
76
80
all_acc , acc , dice = eval_metrics (
77
81
results , label , num_classes , ignore_index , metrics = 'mDice' )
78
82
all_acc_l , acc_l , dice_l = legacy_mean_dice (results , label , num_classes ,
79
83
ignore_index )
80
84
assert all_acc == all_acc_l
81
85
assert np .allclose (acc , acc_l )
82
86
assert np .allclose (dice , dice_l )
83
-
87
+ # Test the correctness of the implementation of joint calculation.
84
88
all_acc , acc , iou , dice = eval_metrics (
85
89
results , label , num_classes , ignore_index , metrics = ['mIoU' , 'mDice' ])
86
90
assert all_acc == all_acc_l
87
91
assert np .allclose (acc , acc_l )
88
92
assert np .allclose (iou , iou_l )
89
93
assert np .allclose (dice , dice_l )
90
94
95
+ # Test the correctness of calculation when arg: num_classes is larger
96
+ # than the maximum value of input maps.
91
97
results = np .random .randint (0 , 5 , size = pred_size )
92
98
label = np .random .randint (0 , 4 , size = pred_size )
93
99
all_acc , acc , iou = eval_metrics (
@@ -121,6 +127,17 @@ def test_metrics():
121
127
assert dice [- 1 ] == - 1
122
128
assert iou [- 1 ] == - 1
123
129
130
+ # Test the bug which is caused by torch.histc.
131
+ # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
132
+ # When the arg:bins is set to be same as arg:max,
133
+ # some channels of mIoU may be nan.
134
+ results = np .array ([np .repeat (31 , 59 )])
135
+ label = np .array ([np .arange (59 )])
136
+ num_classes = 59
137
+ all_acc , acc , iou = eval_metrics (
138
+ results , label , num_classes , ignore_index = 255 , metrics = 'mIoU' )
139
+ assert not np .any (np .isnan (iou ))
140
+
124
141
125
142
def test_mean_iou ():
126
143
pred_size = (10 , 30 , 30 )
@@ -182,7 +199,7 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
182
199
filenames .append (filename )
183
200
return filenames
184
201
185
- pred_size = (10 , 512 , 1024 )
202
+ pred_size = (10 , 30 , 30 )
186
203
num_classes = 19
187
204
ignore_index = 255
188
205
results = np .random .randint (0 , num_classes , size = pred_size )
0 commit comments