@@ -64,30 +64,36 @@ def test_metrics():
6464 ignore_index = 255
6565 results = np .random .randint (0 , num_classes , size = pred_size )
6666 label = np .random .randint (0 , num_classes , size = pred_size )
67+
68+ # Test the availability of arg: ignore_index.
6769 label [:, 2 , 5 :10 ] = ignore_index
70+
71+ # Test the correctness of the implementation of mIoU calculation.
6872 all_acc , acc , iou = eval_metrics (
6973 results , label , num_classes , ignore_index , metrics = 'mIoU' )
7074 all_acc_l , acc_l , iou_l = legacy_mean_iou (results , label , num_classes ,
7175 ignore_index )
7276 assert all_acc == all_acc_l
7377 assert np .allclose (acc , acc_l )
7478 assert np .allclose (iou , iou_l )
75-
79+ # Test the correctness of the implementation of mDice calculation.
7680 all_acc , acc , dice = eval_metrics (
7781 results , label , num_classes , ignore_index , metrics = 'mDice' )
7882 all_acc_l , acc_l , dice_l = legacy_mean_dice (results , label , num_classes ,
7983 ignore_index )
8084 assert all_acc == all_acc_l
8185 assert np .allclose (acc , acc_l )
8286 assert np .allclose (dice , dice_l )
83-
87+ # Test the correctness of the implementation of joint calculation.
8488 all_acc , acc , iou , dice = eval_metrics (
8589 results , label , num_classes , ignore_index , metrics = ['mIoU' , 'mDice' ])
8690 assert all_acc == all_acc_l
8791 assert np .allclose (acc , acc_l )
8892 assert np .allclose (iou , iou_l )
8993 assert np .allclose (dice , dice_l )
9094
95+ # Test the correctness of calculation when arg: num_classes is larger
96+ # than the maximum value of input maps.
9197 results = np .random .randint (0 , 5 , size = pred_size )
9298 label = np .random .randint (0 , 4 , size = pred_size )
9399 all_acc , acc , iou = eval_metrics (
@@ -121,6 +127,17 @@ def test_metrics():
121127 assert dice [- 1 ] == - 1
122128 assert iou [- 1 ] == - 1
123129
130+ # Test the bug which is caused by torch.histc.
131+ # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
132+ # When the arg:bins is set to be same as arg:max,
133+ # some channels of mIoU may be nan.
134+ results = np .array ([np .repeat (31 , 59 )])
135+ label = np .array ([np .arange (59 )])
136+ num_classes = 59
137+ all_acc , acc , iou = eval_metrics (
138+ results , label , num_classes , ignore_index = 255 , metrics = 'mIoU' )
139+ assert not np .any (np .isnan (iou ))
140+
124141
125142def test_mean_iou ():
126143 pred_size = (10 , 30 , 30 )
@@ -182,7 +199,7 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
182199 filenames .append (filename )
183200 return filenames
184201
185- pred_size = (10 , 512 , 1024 )
202+ pred_size = (10 , 30 , 30 )
186203 num_classes = 19
187204 ignore_index = 255
188205 results = np .random .randint (0 , num_classes , size = pred_size )
0 commit comments