@@ -34,14 +34,16 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
34
34
return area_intersect , area_union , area_pred_label , area_label
35
35
36
36
37
- def mean_iou (results , gt_seg_maps , num_classes , ignore_index ):
37
+ def mean_iou (results , gt_seg_maps , num_classes , ignore_index , nan_to_num = None ):
38
38
"""Calculate Intersection and Union (IoU)
39
39
40
40
Args:
41
41
results (list[ndarray]): List of prediction segmentation maps
42
42
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
43
43
num_classes (int): Number of categories
44
44
ignore_index (int): Index that will be ignored in evaluation.
45
+ nan_to_num (int, optional): If specified, NaN values will be replaced
46
+ by the numbers defined by the user. Default: None.
45
47
46
48
Returns:
47
49
float: Overall accuracy on all images.
@@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
66
68
all_acc = total_area_intersect .sum () / total_area_label .sum ()
67
69
acc = total_area_intersect / total_area_label
68
70
iou = total_area_intersect / total_area_union
69
-
71
+ if nan_to_num is not None :
72
+ return all_acc , np .nan_to_num (acc , nan = nan_to_num ), \
73
+ np .nan_to_num (iou , nan = nan_to_num )
70
74
return all_acc , acc , iou
0 commit comments