+from collections import OrderedDict
+
 import mmcv
 import numpy as np
 import torch


+def f_score(precision, recall, beta=1):
+    """Calculate the f-score value.
+
+    Args:
+        precision (float | torch.Tensor): The precision value.
+        recall (float | torch.Tensor): The recall value.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        [torch.Tensor]: The f-score value.
+    """
+    score = (1 + beta**2) * (precision * recall) / (
+        (beta**2 * precision) + recall)
+    return score
+
+
 def intersect_and_union(pred_label,
                         label,
                         num_classes,
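The f_score helper above reduces to the harmonic mean of precision and recall when beta=1. A minimal doctest-style sketch (the tensor values are invented purely for illustration):

>>> import torch
>>> precision = torch.tensor([0.5, 0.8])
>>> recall = torch.tensor([0.5, 0.4])
>>> f_score(precision, recall, beta=1)  # (1 + beta^2) * P * R / (beta^2 * P + R)
tensor([0.5000, 0.5333])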
@@ -133,11 +152,12 @@ def mean_iou(results,
         reduce_zero_label (bool): Whether to ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category IoU, shape (num_classes, ).
+        dict[str, float | ndarray]:
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <IoU> ndarray: Per category IoU, shape (num_classes, ).
     """
-    all_acc, acc, iou = eval_metrics(
+    iou_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -146,7 +166,7 @@ def mean_iou(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, iou
+    return iou_result


 def mean_dice(results,
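With this change mean_iou no longer returns an (all_acc, acc, iou) tuple; callers index the returned dict by metric name. A hedged usage sketch with made-up inputs (a single random 4x4 prediction/label pair, 3 classes):

>>> import numpy as np
>>> results = [np.random.randint(0, 3, (4, 4))]
>>> gt_seg_maps = [np.random.randint(0, 3, (4, 4))]
>>> ret = mean_iou(results, gt_seg_maps, num_classes=3, ignore_index=255)
>>> ret['aAcc'], ret['Acc'], ret['IoU']  # replaces the former positional unpacking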
@@ -171,12 +191,13 @@ def mean_dice(results,
         reduce_zero_label (bool): Whether to ignore zero label. Default: False.

     Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, ).
-        ndarray: Per category dice, shape (num_classes, ).
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <Dice> ndarray: Per category dice, shape (num_classes, ).
     """

-    all_acc, acc, dice = eval_metrics(
+    dice_result = eval_metrics(
         results=results,
         gt_seg_maps=gt_seg_maps,
         num_classes=num_classes,
@@ -185,7 +206,52 @@ def mean_dice(results,
         nan_to_num=nan_to_num,
         label_map=label_map,
         reduce_zero_label=reduce_zero_label)
-    return all_acc, acc, dice
+    return dice_result
+
+
+def mean_fscore(results,
+                gt_seg_maps,
+                num_classes,
+                ignore_index,
+                nan_to_num=None,
+                label_map=dict(),
+                reduce_zero_label=False,
+                beta=1):
+    """Calculate Mean F-Score (mFscore).
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether to ignore zero label. Default: False.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
+            <Precision> ndarray: Per category precision, shape (num_classes, ).
+            <Recall> ndarray: Per category recall, shape (num_classes, ).
+    """
+    fscore_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mFscore'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label,
+        beta=beta)
+    return fscore_result


 def eval_metrics(results,
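A brief sketch of the new mean_fscore entry point, reusing the hypothetical results and gt_seg_maps from the earlier sketch; a beta of 2 weights recall more heavily than precision:

>>> ret = mean_fscore(results, gt_seg_maps, num_classes=3, ignore_index=255, beta=2)
>>> sorted(ret.keys())
['Fscore', 'Precision', 'Recall', 'aAcc']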
@@ -195,7 +261,8 @@ def eval_metrics(results,
                  metrics=['mIoU'],
                  nan_to_num=None,
                  label_map=dict(),
-                 reduce_zero_label=False):
+                 reduce_zero_label=False,
+                 beta=1):
     """Calculate evaluation metrics.
     Args:
         results (list[ndarray] | list[str]): List of prediction segmentation
@@ -210,13 +277,13 @@ def eval_metrics(results,
         label_map (dict): Mapping old labels to new labels. Default: dict().
         reduce_zero_label (bool): Whether to ignore zero label. Default: False.
     Returns:
-         float: Overall accuracy on all images.
-         ndarray: Per category accuracy, shape (num_classes, ).
-         ndarray: Per category evaluation metrics, shape (num_classes, ).
+        float: Overall accuracy on all images.
+        ndarray: Per category accuracy, shape (num_classes, ).
+        ndarray: Per category evaluation metrics, shape (num_classes, ).
     """
     if isinstance(metrics, str):
         metrics = [metrics]
-    allowed_metrics = ['mIoU', 'mDice']
+    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
     if not set(metrics).issubset(set(allowed_metrics)):
         raise KeyError('metrics {} is not supported'.format(metrics))

@@ -225,19 +292,35 @@ def eval_metrics(results,
         results, gt_seg_maps, num_classes, ignore_index, label_map,
         reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
-    acc = total_area_intersect / total_area_label
-    ret_metrics = [all_acc, acc]
+    ret_metrics = OrderedDict({'aAcc': all_acc})
     for metric in metrics:
         if metric == 'mIoU':
             iou = total_area_intersect / total_area_union
-            ret_metrics.append(iou)
+            acc = total_area_intersect / total_area_label
+            ret_metrics['IoU'] = iou
+            ret_metrics['Acc'] = acc
         elif metric == 'mDice':
             dice = 2 * total_area_intersect / (
                 total_area_pred_label + total_area_label)
-            ret_metrics.append(dice)
-    ret_metrics = [metric.numpy() for metric in ret_metrics]
+            acc = total_area_intersect / total_area_label
+            ret_metrics['Dice'] = dice
+            ret_metrics['Acc'] = acc
+        elif metric == 'mFscore':
+            precision = total_area_intersect / total_area_pred_label
+            recall = total_area_intersect / total_area_label
+            f_value = torch.tensor(
+                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+            ret_metrics['Fscore'] = f_value
+            ret_metrics['Precision'] = precision
+            ret_metrics['Recall'] = recall
+
+    ret_metrics = {
+        metric: value.numpy()
+        for metric, value in ret_metrics.items()
+    }
     if nan_to_num is not None:
-        ret_metrics = [
-            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
-        ]
+        ret_metrics = OrderedDict({
+            metric: np.nan_to_num(metric_value, nan=nan_to_num)
+            for metric, metric_value in ret_metrics.items()
+        })
     return ret_metrics
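Since eval_metrics now returns an OrderedDict keyed by metric name, several metrics can be requested in one call and NaN entries can be replaced in place. A hedged sketch, again with the hypothetical inputs defined above:

>>> ret = eval_metrics(results, gt_seg_maps, num_classes=3, ignore_index=255,
...                    metrics=['mIoU', 'mFscore'], nan_to_num=0)
>>> list(ret.keys())  # 'aAcc' first, then the keys each requested metric contributes
['aAcc', 'IoU', 'Acc', 'Fscore', 'Precision', 'Recall']
>>> type(ret['IoU'])  # values are converted to numpy before returning
<class 'numpy.ndarray'>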