Skip to content

Commit 5a7996d

Browse files
authored
[Enhancement] Support loading GT for evaluation from multi-file backend (open-mmlab#867)
* support load gt for evaluation from multi-backend * move some code from get_gt_seg_maps to get_one_gt_seg_map * rename gt_seg_map_loader_conf to gt_seg_map_loader_cfg * fix doc str * rename get_one_gt_seg_map to get_gt_seg_map_by_idx
1 parent 56e18ba commit 5a7996d

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

mmseg/datasets/custom.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
1313
from mmseg.utils import get_root_logger
1414
from .builder import DATASETS
15-
from .pipelines import Compose
15+
from .pipelines import Compose, LoadAnnotations
1616

1717

1818
@DATASETS.register_module()
@@ -66,6 +66,8 @@ class CustomDataset(Dataset):
6666
The palette of segmentation map. If None is given, and
6767
self.PALETTE is None, random palette will be generated.
6868
Default: None
69+
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
70+
load gt for evaluation, load from disk by default. Default: None.
6971
"""
7072

7173
CLASSES = None
@@ -84,7 +86,8 @@ def __init__(self,
8486
ignore_index=255,
8587
reduce_zero_label=False,
8688
classes=None,
87-
palette=None):
89+
palette=None,
90+
gt_seg_map_loader_cfg=None):
8891
self.pipeline = Compose(pipeline)
8992
self.img_dir = img_dir
9093
self.img_suffix = img_suffix
@@ -98,6 +101,10 @@ def __init__(self,
98101
self.label_map = None
99102
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
100103
classes, palette)
104+
self.gt_seg_map_loader = LoadAnnotations(
105+
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
106+
**gt_seg_map_loader_cfg)
107+
101108
if test_mode:
102109
assert self.CLASSES is not None, \
103110
'`cls.CLASSES` or `classes` should be specified when testing'
@@ -232,6 +239,14 @@ def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
232239
"""Place holder to format result to dataset specific output."""
233240
raise NotImplementedError
234241

242+
def get_gt_seg_map_by_idx(self, index):
243+
"""Get one ground truth segmentation map for evaluation."""
244+
ann_info = self.get_ann_info(index)
245+
results = dict(ann_info=ann_info)
246+
self.pre_pipeline(results)
247+
self.gt_seg_map_loader(results)
248+
return results['gt_semantic_seg']
249+
235250
def get_gt_seg_maps(self, efficient_test=None):
236251
"""Get ground truth segmentation maps for evaluation."""
237252
if efficient_test is not None:
@@ -240,11 +255,12 @@ def get_gt_seg_maps(self, efficient_test=None):
240255
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
241256
'friendly by default. ')
242257

243-
for img_info in self.img_infos:
244-
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
245-
gt_seg_map = mmcv.imread(
246-
seg_map, flag='unchanged', backend='pillow')
247-
yield gt_seg_map
258+
for idx in range(len(self)):
259+
ann_info = self.get_ann_info(idx)
260+
results = dict(ann_info=ann_info)
261+
self.pre_pipeline(results)
262+
self.gt_seg_map_loader(results)
263+
yield results['gt_semantic_seg']
248264

249265
def pre_eval(self, preds, indices):
250266
"""Collect eval result from each iteration.
@@ -268,9 +284,7 @@ def pre_eval(self, preds, indices):
268284
pre_eval_results = []
269285

270286
for pred, index in zip(preds, indices):
271-
seg_map = osp.join(self.ann_dir,
272-
self.img_infos[index]['ann']['seg_map'])
273-
seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')
287+
seg_map = self.get_gt_seg_map_by_idx(index)
274288
pre_eval_results.append(
275289
intersect_and_union(pred, seg_map, len(self.CLASSES),
276290
self.ignore_index, self.label_map,

0 commit comments

Comments
 (0)