11import mmcv
22import numpy as np
3+ import torch
34
45
56def intersect_and_union (pred_label ,
@@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
1112 """Calculate intersection and Union.
1213
1314 Args:
14- pred_label (ndarray): Prediction segmentation map.
15- label (ndarray): Ground truth segmentation map.
15+ pred_label (ndarray | str): Prediction segmentation map
16+ or predict result filename.
17+ label (ndarray | str): Ground truth segmentation map
18+ or label filename.
1619 num_classes (int): Number of categories.
1720 ignore_index (int): Index that will be ignored in evaluation.
1821 label_map (dict): Mapping old labels to new labels. The parameter will
@@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
2124 work only when label is str. Default: False.
2225
2326 Returns:
24- ndarray : The intersection of prediction and ground truth histogram
25- on all classes.
26- ndarray : The union of prediction and ground truth histogram on all
27- classes.
28- ndarray : The prediction histogram on all classes.
29- ndarray : The ground truth histogram on all classes.
27+ torch.Tensor : The intersection of prediction and ground truth
28+ histogram on all classes.
29+ torch.Tensor : The union of prediction and ground truth histogram on
30+ all classes.
31+ torch.Tensor : The prediction histogram on all classes.
32+ torch.Tensor : The ground truth histogram on all classes.
3033 """
3134
3235 if isinstance (pred_label , str ):
33- pred_label = np .load (pred_label )
36+ pred_label = torch .from_numpy (np .load (pred_label ))
37+ else :
38+ pred_label = torch .from_numpy ((pred_label ))
3439
3540 if isinstance (label , str ):
36- label = mmcv .imread (label , flag = 'unchanged' , backend = 'pillow' )
37- # modify if custom classes
41+ label = torch .from_numpy (
42+ mmcv .imread (label , flag = 'unchanged' , backend = 'pillow' ))
43+ else :
44+ label = torch .from_numpy (label )
45+
3846 if label_map is not None :
3947 for old_id , new_id in label_map .items ():
4048 label [label == old_id ] = new_id
4149 if reduce_zero_label :
42- # avoid using underflow conversion
4350 label [label == 0 ] = 255
4451 label = label - 1
4552 label [label == 254 ] = 255
@@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
4956 label = label [mask ]
5057
5158 intersect = pred_label [pred_label == label ]
52- area_intersect , _ = np .histogram (
53- intersect , bins = np .arange (num_classes + 1 ))
54- area_pred_label , _ = np .histogram (
55- pred_label , bins = np .arange (num_classes + 1 ))
56- area_label , _ = np .histogram (label , bins = np .arange (num_classes + 1 ))
59+ area_intersect = torch .histc (
60+ intersect .float (), bins = (num_classes ), min = 0 , max = num_classes )
61+ area_pred_label = torch .histc (
62+ pred_label .float (), bins = (num_classes ), min = 0 , max = num_classes )
63+ area_label = torch .histc (
64+ label .float (), bins = (num_classes ), min = 0 , max = num_classes )
5765 area_union = area_pred_label + area_label - area_intersect
58-
5966 return area_intersect , area_union , area_pred_label , area_label
6067
6168
@@ -68,8 +75,10 @@ def total_intersect_and_union(results,
6875 """Calculate Total Intersection and Union.
6976
7077 Args:
71- results (list[ndarray]): List of prediction segmentation maps.
72- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
78+ results (list[ndarray] | list[str]): List of prediction segmentation
79+ maps or list of prediction result filenames.
80+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
81+ segmentation maps or list of label filenames.
7382 num_classes (int): Number of categories.
7483 ignore_index (int): Index that will be ignored in evaluation.
7584 label_map (dict): Mapping old labels to new labels. Default: dict().
@@ -83,23 +92,23 @@ def total_intersect_and_union(results,
8392 ndarray: The prediction histogram on all classes.
8493 ndarray: The ground truth histogram on all classes.
8594 """
86-
8795 num_imgs = len (results )
8896 assert len (gt_seg_maps ) == num_imgs
89- total_area_intersect = np .zeros ((num_classes , ), dtype = np . float )
90- total_area_union = np .zeros ((num_classes , ), dtype = np . float )
91- total_area_pred_label = np .zeros ((num_classes , ), dtype = np . float )
92- total_area_label = np .zeros ((num_classes , ), dtype = np . float )
97+ total_area_intersect = torch .zeros ((num_classes , ), dtype = torch . float64 )
98+ total_area_union = torch .zeros ((num_classes , ), dtype = torch . float64 )
99+ total_area_pred_label = torch .zeros ((num_classes , ), dtype = torch . float64 )
100+ total_area_label = torch .zeros ((num_classes , ), dtype = torch . float64 )
93101 for i in range (num_imgs ):
94102 area_intersect , area_union , area_pred_label , area_label = \
95- intersect_and_union (results [i ], gt_seg_maps [i ], num_classes ,
96- ignore_index , label_map , reduce_zero_label )
103+ intersect_and_union (
104+ results [i ], gt_seg_maps [i ], num_classes , ignore_index ,
105+ label_map , reduce_zero_label )
97106 total_area_intersect += area_intersect
98107 total_area_union += area_union
99108 total_area_pred_label += area_pred_label
100109 total_area_label += area_label
101- return total_area_intersect , total_area_union , \
102- total_area_pred_label , total_area_label
110+ return total_area_intersect , total_area_union , total_area_pred_label , \
111+ total_area_label
103112
104113
105114def mean_iou (results ,
@@ -112,8 +121,10 @@ def mean_iou(results,
112121 """Calculate Mean Intersection and Union (mIoU)
113122
114123 Args:
115- results (list[ndarray]): List of prediction segmentation maps.
116- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
124+ results (list[ndarray] | list[str]): List of prediction segmentation
125+ maps or list of prediction result filenames.
126+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
127+ segmentation maps or list of label filenames.
117128 num_classes (int): Number of categories.
118129 ignore_index (int): Index that will be ignored in evaluation.
119130 nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -126,7 +137,6 @@ def mean_iou(results,
126137 ndarray: Per category accuracy, shape (num_classes, ).
127138 ndarray: Per category IoU, shape (num_classes, ).
128139 """
129-
130140 all_acc , acc , iou = eval_metrics (
131141 results = results ,
132142 gt_seg_maps = gt_seg_maps ,
@@ -149,8 +159,10 @@ def mean_dice(results,
149159 """Calculate Mean Dice (mDice)
150160
151161 Args:
152- results (list[ndarray]): List of prediction segmentation maps.
153- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
162+ results (list[ndarray] | list[str]): List of prediction segmentation
163+ maps or list of prediction result filenames.
164+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
165+ segmentation maps or list of label filenames.
154166 num_classes (int): Number of categories.
155167 ignore_index (int): Index that will be ignored in evaluation.
156168 nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -186,8 +198,10 @@ def eval_metrics(results,
186198 reduce_zero_label = False ):
187199 """Calculate evaluation metrics
188200 Args:
189- results (list[ndarray]): List of prediction segmentation maps.
190- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
201+ results (list[ndarray] | list[str]): List of prediction segmentation
202+ maps or list of prediction result filenames.
203+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
204+ segmentation maps or list of label filenames.
191205 num_classes (int): Number of categories.
192206 ignore_index (int): Index that will be ignored in evaluation.
193207 metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@@ -200,17 +214,16 @@ def eval_metrics(results,
200214 ndarray: Per category accuracy, shape (num_classes, ).
201215 ndarray: Per category evalution metrics, shape (num_classes, ).
202216 """
203-
204217 if isinstance (metrics , str ):
205218 metrics = [metrics ]
206219 allowed_metrics = ['mIoU' , 'mDice' ]
207220 if not set (metrics ).issubset (set (allowed_metrics )):
208221 raise KeyError ('metrics {} is not supported' .format (metrics ))
222+
209223 total_area_intersect , total_area_union , total_area_pred_label , \
210- total_area_label = total_intersect_and_union (results , gt_seg_maps ,
211- num_classes , ignore_index ,
212- label_map ,
213- reduce_zero_label )
224+ total_area_label = total_intersect_and_union (
225+ results , gt_seg_maps , num_classes , ignore_index , label_map ,
226+ reduce_zero_label )
214227 all_acc = total_area_intersect .sum () / total_area_label .sum ()
215228 acc = total_area_intersect / total_area_label
216229 ret_metrics = [all_acc , acc ]
@@ -222,6 +235,7 @@ def eval_metrics(results,
222235 dice = 2 * total_area_intersect / (
223236 total_area_pred_label + total_area_label )
224237 ret_metrics .append (dice )
238+ ret_metrics = [metric .numpy () for metric in ret_metrics ]
225239 if nan_to_num is not None :
226240 ret_metrics = [
227241 np .nan_to_num (metric , nan = nan_to_num ) for metric in ret_metrics
0 commit comments