1212from mmseg .core import eval_metrics , intersect_and_union , pre_eval_to_metrics
1313from mmseg .utils import get_root_logger
1414from .builder import DATASETS
15- from .pipelines import Compose
15+ from .pipelines import Compose , LoadAnnotations
1616
1717
1818@DATASETS .register_module ()
@@ -66,6 +66,8 @@ class CustomDataset(Dataset):
6666 The palette of segmentation map. If None is given, and
6767 self.PALETTE is None, random palette will be generated.
6868 Default: None
69+ gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
70+ load gt for evaluation, load from disk by default. Default: None.
6971 """
7072
7173 CLASSES = None
@@ -84,7 +86,8 @@ def __init__(self,
8486 ignore_index = 255 ,
8587 reduce_zero_label = False ,
8688 classes = None ,
87- palette = None ):
89+ palette = None ,
90+ gt_seg_map_loader_cfg = None ):
8891 self .pipeline = Compose (pipeline )
8992 self .img_dir = img_dir
9093 self .img_suffix = img_suffix
@@ -98,6 +101,10 @@ def __init__(self,
98101 self .label_map = None
99102 self .CLASSES , self .PALETTE = self .get_classes_and_palette (
100103 classes , palette )
104+ self .gt_seg_map_loader = LoadAnnotations (
105+ ) if gt_seg_map_loader_cfg is None else LoadAnnotations (
106+ ** gt_seg_map_loader_cfg )
107+
101108 if test_mode :
102109 assert self .CLASSES is not None , \
103110 '`cls.CLASSES` or `classes` should be specified when testing'
@@ -232,6 +239,14 @@ def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
232239 """Place holder to format result to dataset specific output."""
233240 raise NotImplementedError
234241
242+ def get_gt_seg_map_by_idx (self , index ):
243+ """Get one ground truth segmentation map for evaluation."""
244+ ann_info = self .get_ann_info (index )
245+ results = dict (ann_info = ann_info )
246+ self .pre_pipeline (results )
247+ self .gt_seg_map_loader (results )
248+ return results ['gt_semantic_seg' ]
249+
235250 def get_gt_seg_maps (self , efficient_test = None ):
236251 """Get ground truth segmentation maps for evaluation."""
237252 if efficient_test is not None :
@@ -240,11 +255,12 @@ def get_gt_seg_maps(self, efficient_test=None):
240255 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
241256 'friendly by default. ' )
242257
243- for img_info in self .img_infos :
244- seg_map = osp .join (self .ann_dir , img_info ['ann' ]['seg_map' ])
245- gt_seg_map = mmcv .imread (
246- seg_map , flag = 'unchanged' , backend = 'pillow' )
247- yield gt_seg_map
258+ for idx in range (len (self )):
259+ ann_info = self .get_ann_info (idx )
260+ results = dict (ann_info = ann_info )
261+ self .pre_pipeline (results )
262+ self .gt_seg_map_loader (results )
263+ yield results ['gt_semantic_seg' ]
248264
249265 def pre_eval (self , preds , indices ):
250266 """Collect eval result from each iteration.
@@ -268,9 +284,7 @@ def pre_eval(self, preds, indices):
268284 pre_eval_results = []
269285
270286 for pred , index in zip (preds , indices ):
271- seg_map = osp .join (self .ann_dir ,
272- self .img_infos [index ]['ann' ]['seg_map' ])
273- seg_map = mmcv .imread (seg_map , flag = 'unchanged' , backend = 'pillow' )
287+ seg_map = self .get_gt_seg_map_by_idx (index )
274288 pre_eval_results .append (
275289 intersect_and_union (pred , seg_map , len (self .CLASSES ),
276290 self .ignore_index , self .label_map ,
0 commit comments