@@ -31,7 +31,7 @@ def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
3131def legacy_mean_iou (results , gt_seg_maps , num_classes , ignore_index ):
3232 num_imgs = len (results )
3333 assert len (gt_seg_maps ) == num_imgs
34- total_mat = np .zeros ((num_classes , num_classes ), dtype = np .float )
34+ total_mat = np .zeros ((num_classes , num_classes ), dtype = np .float32 )
3535 for i in range (num_imgs ):
3636 mat = get_confusion_matrix (
3737 results [i ], gt_seg_maps [i ], num_classes , ignore_index = ignore_index )
@@ -48,7 +48,7 @@ def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
4848def legacy_mean_dice (results , gt_seg_maps , num_classes , ignore_index ):
4949 num_imgs = len (results )
5050 assert len (gt_seg_maps ) == num_imgs
51- total_mat = np .zeros ((num_classes , num_classes ), dtype = np .float )
51+ total_mat = np .zeros ((num_classes , num_classes ), dtype = np .float32 )
5252 for i in range (num_imgs ):
5353 mat = get_confusion_matrix (
5454 results [i ], gt_seg_maps [i ], num_classes , ignore_index = ignore_index )
@@ -69,7 +69,7 @@ def legacy_mean_fscore(results,
6969 beta = 1 ):
7070 num_imgs = len (results )
7171 assert len (gt_seg_maps ) == num_imgs
72- total_mat = np .zeros ((num_classes , num_classes ), dtype = np .float )
72+ total_mat = np .zeros ((num_classes , num_classes ), dtype = np .float32 )
7373 for i in range (num_imgs ):
7474 mat = get_confusion_matrix (
7575 results [i ], gt_seg_maps [i ], num_classes , ignore_index = ignore_index )
@@ -100,7 +100,7 @@ def test_metrics():
100100 'IoU' ]
101101 all_acc_l , acc_l , iou_l = legacy_mean_iou (results , label , num_classes ,
102102 ignore_index )
103- assert all_acc == all_acc_l
103+ assert np . allclose ( all_acc , all_acc_l )
104104 assert np .allclose (acc , acc_l )
105105 assert np .allclose (iou , iou_l )
106106 # Test the correctness of the implementation of mDice calculation.
@@ -110,7 +110,7 @@ def test_metrics():
110110 'Dice' ]
111111 all_acc_l , acc_l , dice_l = legacy_mean_dice (results , label , num_classes ,
112112 ignore_index )
113- assert all_acc == all_acc_l
113+ assert np . allclose ( all_acc , all_acc_l )
114114 assert np .allclose (acc , acc_l )
115115 assert np .allclose (dice , dice_l )
116116 # Test the correctness of the implementation of mDice calculation.
@@ -120,7 +120,7 @@ def test_metrics():
120120 'Recall' ], ret_metrics ['Precision' ], ret_metrics ['Fscore' ]
121121 all_acc_l , recall_l , precision_l , fscore_l = legacy_mean_fscore (
122122 results , label , num_classes , ignore_index )
123- assert all_acc == all_acc_l
123+ assert np . allclose ( all_acc , all_acc_l )
124124 assert np .allclose (recall , recall_l )
125125 assert np .allclose (precision , precision_l )
126126 assert np .allclose (fscore , fscore_l )
@@ -135,7 +135,7 @@ def test_metrics():
135135 'aAcc' ], ret_metrics ['Acc' ], ret_metrics ['IoU' ], ret_metrics [
136136 'Dice' ], ret_metrics ['Precision' ], ret_metrics [
137137 'Recall' ], ret_metrics ['Fscore' ]
138- assert all_acc == all_acc_l
138+ assert np . allclose ( all_acc , all_acc_l )
139139 assert np .allclose (acc , acc_l )
140140 assert np .allclose (iou , iou_l )
141141 assert np .allclose (dice , dice_l )
@@ -228,7 +228,7 @@ def test_mean_iou():
228228 'IoU' ]
229229 all_acc_l , acc_l , iou_l = legacy_mean_iou (results , label , num_classes ,
230230 ignore_index )
231- assert all_acc == all_acc_l
231+ assert np . allclose ( all_acc , all_acc_l )
232232 assert np .allclose (acc , acc_l )
233233 assert np .allclose (iou , iou_l )
234234
@@ -254,7 +254,7 @@ def test_mean_dice():
254254 'Dice' ]
255255 all_acc_l , acc_l , dice_l = legacy_mean_dice (results , label , num_classes ,
256256 ignore_index )
257- assert all_acc == all_acc_l
257+ assert np . allclose ( all_acc , all_acc_l )
258258 assert np .allclose (acc , acc_l )
259259 assert np .allclose (iou , dice_l )
260260
@@ -280,7 +280,7 @@ def test_mean_fscore():
280280 'Recall' ], ret_metrics ['Precision' ], ret_metrics ['Fscore' ]
281281 all_acc_l , recall_l , precision_l , fscore_l = legacy_mean_fscore (
282282 results , label , num_classes , ignore_index )
283- assert all_acc == all_acc_l
283+ assert np . allclose ( all_acc , all_acc_l )
284284 assert np .allclose (recall , recall_l )
285285 assert np .allclose (precision , precision_l )
286286 assert np .allclose (fscore , fscore_l )
@@ -291,7 +291,7 @@ def test_mean_fscore():
291291 'Recall' ], ret_metrics ['Precision' ], ret_metrics ['Fscore' ]
292292 all_acc_l , recall_l , precision_l , fscore_l = legacy_mean_fscore (
293293 results , label , num_classes , ignore_index , beta = 2 )
294- assert all_acc == all_acc_l
294+ assert np . allclose ( all_acc , all_acc_l )
295295 assert np .allclose (recall , recall_l )
296296 assert np .allclose (precision , precision_l )
297297 assert np .allclose (fscore , fscore_l )
@@ -346,6 +346,6 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
346346 'Acc' ], ret_metrics ['IoU' ]
347347 all_acc_l , acc_l , iou_l = legacy_mean_iou (results , labels , num_classes ,
348348 ignore_index )
349- assert all_acc == all_acc_l
349+ assert np . allclose ( all_acc , all_acc_l )
350350 assert np .allclose (acc , acc_l )
351351 assert np .allclose (iou , iou_l )
0 commit comments