@@ -34,14 +34,16 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
3434 return area_intersect , area_union , area_pred_label , area_label
3535
3636
37- def mean_iou (results , gt_seg_maps , num_classes , ignore_index ):
37+ def mean_iou (results , gt_seg_maps , num_classes , ignore_index , nan_to_num = None ):
3838 """Calculate Intersection and Union (IoU)
3939
4040 Args:
4141 results (list[ndarray]): List of prediction segmentation maps
4242 gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
4343 num_classes (int): Number of categories
4444 ignore_index (int): Index that will be ignored in evaluation.
45+ nan_to_num (int, optional): If specified, NaN values will be replaced
46+ by the numbers defined by the user. Default: None.
4547
4648 Returns:
4749 float: Overall accuracy on all images.
@@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
6668 all_acc = total_area_intersect .sum () / total_area_label .sum ()
6769 acc = total_area_intersect / total_area_label
6870 iou = total_area_intersect / total_area_union
69-
71+ if nan_to_num is not None :
72+ return all_acc , np .nan_to_num (acc , nan = nan_to_num ), \
73+ np .nan_to_num (iou , nan = nan_to_num )
7074 return all_acc , acc , iou
0 commit comments