
Commit 993be25

add dice evaluation metric (open-mmlab#225)
* add dice evaluation metric
* add dice evaluation metric
* add dice evaluation metric
* support 2 metrics
* support 2 metrics
* support 2 metrics
* support 2 metrics
* fix docstring
* use np.round once for all
1 parent 90e8e38 commit 993be25

9 files changed: +420 -179 lines changed


mmseg/core/evaluation/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .mean_iou import mean_iou
+from .metrics import eval_metrics, mean_dice, mean_iou

 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+    'get_classes', 'get_palette'
 ]
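
With this re-export in place, all three entry points are importable from mmseg.core. A small orientation sketch (not part of the commit itself):

    from mmseg.core import eval_metrics, mean_dice, mean_iou

    # mean_iou and mean_dice are now thin wrappers around eval_metrics,
    # so these two calls are equivalent:
    #   mean_iou(results, gt_seg_maps, num_classes, ignore_index)
    #   eval_metrics(results, gt_seg_maps, num_classes, ignore_index,
    #                metrics=['mIoU'])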

mmseg/core/evaluation/mean_iou.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

mmseg/core/evaluation/metrics.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@

import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
    """Calculate Intersection and Union.

    Args:
        pred_label (ndarray): Prediction segmentation map.
        label (ndarray): Ground truth segmentation map.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1))
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1))
    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.

    Returns:
        ndarray: The intersection of prediction and ground truth histogram
            on all classes.
        ndarray: The union of prediction and ground truth histogram on all
            classes.
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """

    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
    total_area_union = np.zeros((num_classes, ), dtype=np.float)
    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
    total_area_label = np.zeros((num_classes, ), dtype=np.float)
    for i in range(num_imgs):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
                                ignore_index=ignore_index)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, \
        total_area_pred_label, total_area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
    """Calculate Mean Intersection over Union (mIoU).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category IoU, shape (num_classes, ).
    """

    all_acc, acc, iou = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num)
    return all_acc, acc, iou


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None):
    """Calculate Mean Dice (mDice).

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category dice, shape (num_classes, ).
    """

    all_acc, acc, dice = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num)
    return all_acc, acc, dice


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """

    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))
    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(results, gt_seg_maps,
                                                     num_classes,
                                                     ignore_index=ignore_index)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    acc = total_area_intersect / total_area_label
    ret_metrics = [all_acc, acc]
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            ret_metrics.append(iou)
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            ret_metrics.append(dice)
    if nan_to_num is not None:
        ret_metrics = [
            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
        ]
    return ret_metrics
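
To make the return convention concrete, here is a minimal sketch (not part of the commit) that exercises the new eval_metrics on a toy two-class prediction. Requesting both metrics returns [all_acc, acc, iou, dice], with one array appended per requested metric, in order:

    import numpy as np

    from mmseg.core import eval_metrics

    pred = np.array([[0, 0], [1, 1]])  # predicted segmentation map
    gt = np.array([[0, 1], [1, 1]])    # ground truth segmentation map

    # 3 of 4 pixels match; per class: intersect=[1, 2], union=[2, 3],
    # pred area=[2, 2], gt area=[1, 3].
    all_acc, acc, iou, dice = eval_metrics(
        [pred], [gt], num_classes=2, ignore_index=255,
        metrics=['mIoU', 'mDice'])

    print(all_acc)  # 0.75
    print(acc)      # [1.0,   0.667]  intersect / gt area
    print(iou)      # [0.5,   0.667]  intersect / union
    print(dice)     # [0.667, 0.8  ]  2 * intersect / (pred area + gt area)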

mmseg/datasets/custom.py

Lines changed: 43 additions & 37 deletions
@@ -4,19 +4,19 @@
 import mmcv
 import numpy as np
 from mmcv.utils import print_log
+from terminaltables import AsciiTable
 from torch.utils.data import Dataset

-from mmseg.core import mean_iou
+from mmseg.core import eval_metrics
 from mmseg.utils import get_root_logger
 from .builder import DATASETS
 from .pipelines import Compose


 @DATASETS.register_module()
 class CustomDataset(Dataset):
-    """Custom dataset for semantic segmentation.
-
-    An example of file structure is as followed.
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as followed.

     .. code-block:: none

@@ -315,57 +315,63 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):

         Args:
             results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
+                'mDice' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.

         Returns:
             dict[str, float]: Default metrics.
         """

-        if not isinstance(metric, str):
-            assert len(metric) == 1
-            metric = metric[0]
-        allowed_metrics = ['mIoU']
-        if metric not in allowed_metrics:
+        if isinstance(metric, str):
+            metric = [metric]
+        allowed_metrics = ['mIoU', 'mDice']
+        if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))
-
         eval_results = {}
         gt_seg_maps = self.get_gt_seg_maps()
         if self.CLASSES is None:
             num_classes = len(
                 reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
         else:
             num_classes = len(self.CLASSES)
-
-        all_acc, acc, iou = mean_iou(
-            results, gt_seg_maps, num_classes, ignore_index=self.ignore_index)
-        summary_str = ''
-        summary_str += 'per class results:\n'
-
-        line_format = '{:<15} {:>10} {:>10}\n'
-        summary_str += line_format.format('Class', 'IoU', 'Acc')
+        ret_metrics = eval_metrics(
+            results,
+            gt_seg_maps,
+            num_classes,
+            ignore_index=self.ignore_index,
+            metrics=metric)
+        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))
         else:
             class_names = self.CLASSES
+        ret_metrics_round = [
+            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
+        ]
         for i in range(num_classes):
-            iou_str = '{:.2f}'.format(iou[i] * 100)
-            acc_str = '{:.2f}'.format(acc[i] * 100)
-            summary_str += line_format.format(class_names[i], iou_str, acc_str)
-        summary_str += 'Summary:\n'
-        line_format = '{:<15} {:>10} {:>10} {:>10}\n'
-        summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')
-
-        iou_str = '{:.2f}'.format(np.nanmean(iou) * 100)
-        acc_str = '{:.2f}'.format(np.nanmean(acc) * 100)
-        all_acc_str = '{:.2f}'.format(all_acc * 100)
-        summary_str += line_format.format('global', iou_str, acc_str,
-                                          all_acc_str)
-        print_log(summary_str, logger)
-
-        eval_results['mIoU'] = np.nanmean(iou)
-        eval_results['mAcc'] = np.nanmean(acc)
-        eval_results['aAcc'] = all_acc
-
+            class_table_data.append([class_names[i]] +
+                                    [m[i] for m in ret_metrics_round[2:]] +
+                                    [ret_metrics_round[1][i]])
+        summary_table_data = [['Scope'] +
+                              ['m' + head
+                               for head in class_table_data[0][1:]] + ['aAcc']]
+        ret_metrics_mean = [
+            np.round(np.nanmean(ret_metric) * 100, 2)
+            for ret_metric in ret_metrics
+        ]
+        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
+                                  [ret_metrics_mean[1]] +
+                                  [ret_metrics_mean[0]])
+        print_log('per class results:', logger)
+        table = AsciiTable(class_table_data)
+        print_log('\n' + table.table, logger=logger)
+        print_log('Summary:', logger)
+        table = AsciiTable(summary_table_data)
+        print_log('\n' + table.table, logger=logger)
+
+        for i in range(1, len(summary_table_data[0])):
+            eval_results[summary_table_data[0]
+                         [i]] = summary_table_data[1][i] / 100.0
         return eval_results
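
The net effect is that CustomDataset.evaluate now accepts several metrics at once, prints per-class and summary tables via terminaltables, and returns a dict keyed off the summary header. An illustrative sketch (the dataset, results, and all numbers below are hypothetical, not from the commit):

    # 'results' is the usual list of per-image prediction maps, e.g. from
    # mmseg.apis.single_gpu_test; 'dataset' is any CustomDataset subclass.
    eval_results = dataset.evaluate(results, metric=['mIoU', 'mDice'])

    # Printed output (values made up):
    #
    # per class results:
    # +-------+-------+-------+-------+
    # | Class | IoU   | Dice  | Acc   |
    # +-------+-------+-------+-------+
    # | road  | 78.43 | 85.71 | 84.12 |
    # | ...   | ...   | ...   | ...   |
    # +-------+-------+-------+-------+
    # Summary:
    # +--------+-------+-------+-------+-------+
    # | Scope  | mIoU  | mDice | mAcc  | aAcc  |
    # +--------+-------+-------+-------+-------+
    # | global | 78.43 | 85.71 | 84.12 | 95.63 |
    # +--------+-------+-------+-------+-------+
    #
    # The returned dict mirrors the summary row, rescaled to [0, 1]:
    # {'mIoU': 0.7843, 'mDice': 0.8571, 'mAcc': 0.8412, 'aAcc': 0.9563}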

requirements/runtime.txt

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 matplotlib
 numpy
+terminaltables
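
The new runtime dependency exists only for the table rendering above, and the API surface used is tiny. A standalone sketch of the AsciiTable pattern:

    from terminaltables import AsciiTable

    # Same pattern as CustomDataset.evaluate: header row first, then data rows.
    table = AsciiTable([['Class', 'IoU', 'Acc'],
                        ['road', '98.12', '99.01']])
    print(table.table)
    # +-------+-------+-------+
    # | Class | IoU   | Acc   |
    # +-------+-------+-------+
    # | road  | 98.12 | 99.01 |
    # +-------+-------+-------+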

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
