
Commit 6c3e63e

Authored by lkm2835, Younghoon-Lee, and MeowZheng
[Feature] Add MultiImageMixDataset (#1105)
* Fix typo in usage example
* original MultiImageMixDataset code in mmdet
* Add MultiImageMixDataset unittests in test_dataset_wrapper
* fix lint error
* fix value name ann_file to ann_dir
* modify retrieve_data_cfg (#1)
* remove dynamic_scale & add palette
* modify retrieve_data_cfg method
* modify retrieve_data_cfg func
* fix error
* improve the unittests coverage
* fix unittests error
* Dataset (#2)
* add cfg-options
* Add unittest in test_build_dataset
* add blank line
* add blank line
* add a blank line

Co-authored-by: Miao Zheng <[email protected]>
Co-authored-by: Younghoon-Lee <[email protected]>
Co-authored-by: MeowZheng <[email protected]>
Co-authored-by: Miao Zheng <[email protected]>
1 parent f0262fa commit 6c3e63e

File tree: 6 files changed (+190, −16 lines)


mmseg/datasets/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -6,7 +6,8 @@
 from .coco_stuff import COCOStuffDataset
 from .custom import CustomDataset
 from .dark_zurich import DarkZurichDataset
-from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
+                               RepeatDataset)
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .loveda import LoveDADataset
@@ -21,5 +22,5 @@
     'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
     'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
     'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
-    'COCOStuffDataset', 'LoveDADataset'
+    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset'
 ]

mmseg/datasets/builder.py

Lines changed: 7 additions & 1 deletion

@@ -64,12 +64,18 @@ def _concat_dataset(cfg, default_args=None):

 def build_dataset(cfg, default_args=None):
     """Build datasets."""
-    from .dataset_wrappers import ConcatDataset, RepeatDataset
+    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
+                                   MultiImageMixDataset)
     if isinstance(cfg, (list, tuple)):
         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
     elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif cfg['type'] == 'MultiImageMixDataset':
+        cp_cfg = copy.deepcopy(cfg)
+        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
+        cp_cfg.pop('type')
+        dataset = MultiImageMixDataset(**cp_cfg)
     elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
             cfg.get('split', None), (list, tuple)):
         dataset = _concat_dataset(cfg, default_args)
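
The new branch deep-copies the config, builds the inner dataset first, and forwards the remaining keys to the wrapper. A minimal sketch of a config that exercises this path, assuming a standard ADE20K-style setup; the data paths and pipeline contents are illustrative placeholders, not part of this commit:

# Sketch only: the data_root/img_dir/ann_dir values are placeholders.
from mmseg.datasets import build_dataset

dataset_cfg = dict(
    type='MultiImageMixDataset',
    # The inner dataset is built first via build_dataset(cp_cfg['dataset']),
    # then passed to MultiImageMixDataset as the `dataset` argument.
    dataset=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
        ]),
    # The mixing pipeline runs in the wrapper, after the inner pipeline.
    pipeline=[
        dict(type='RandomFlip', prob=0.5),
        dict(type='Resize', img_scale=(512, 512), keep_ratio=False),
    ])

dataset = build_dataset(dataset_cfg)  # returns a MultiImageMixDataset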

mmseg/datasets/dataset_wrappers.py

Lines changed: 89 additions & 2 deletions

@@ -1,13 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import bisect
+import collections
+import copy
 from itertools import chain

 import mmcv
 import numpy as np
-from mmcv.utils import print_log
+from mmcv.utils import build_from_cfg, print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

-from .builder import DATASETS
+from .builder import DATASETS, PIPELINES
 from .cityscapes import CityscapesDataset


@@ -188,3 +190,88 @@ def __getitem__(self, idx):
     def __len__(self):
         """The length is multiplied by ``times``"""
         return self.times * self._ori_len
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+    """A wrapper of multiple images mixed dataset.
+
+    Suitable for training on multiple images mixed data augmentation like
+    mosaic and mixup. For the augmentation pipeline of mixed image data,
+    the `get_indexes` method needs to be provided to obtain the image
+    indexes, and you can set `skip_type_keys` to change the pipeline
+    running process.
+
+    Args:
+        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        pipeline (Sequence[dict]): Sequence of transform objects or
+            config dicts to be composed.
+        skip_type_keys (list[str], optional): Sequence of type strings
+            to be skipped in the pipeline. Defaults to None.
+    """
+
+    def __init__(self, dataset, pipeline, skip_type_keys=None):
+        assert isinstance(pipeline, collections.abc.Sequence)
+        if skip_type_keys is not None:
+            assert all([
+                isinstance(skip_type_key, str)
+                for skip_type_key in skip_type_keys
+            ])
+        self._skip_type_keys = skip_type_keys
+
+        self.pipeline = []
+        self.pipeline_types = []
+        for transform in pipeline:
+            if isinstance(transform, dict):
+                self.pipeline_types.append(transform['type'])
+                transform = build_from_cfg(transform, PIPELINES)
+                self.pipeline.append(transform)
+            else:
+                raise TypeError('pipeline must be a dict')
+
+        self.dataset = dataset
+        self.CLASSES = dataset.CLASSES
+        self.PALETTE = dataset.PALETTE
+        self.num_samples = len(dataset)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        results = copy.deepcopy(self.dataset[idx])
+        for (transform, transform_type) in zip(self.pipeline,
+                                               self.pipeline_types):
+            if self._skip_type_keys is not None and \
+                    transform_type in self._skip_type_keys:
+                continue
+
+            if hasattr(transform, 'get_indexes'):
+                indexes = transform.get_indexes(self.dataset)
+                if not isinstance(indexes, collections.abc.Sequence):
+                    indexes = [indexes]
+                mix_results = [
+                    copy.deepcopy(self.dataset[index]) for index in indexes
+                ]
+                results['mix_results'] = mix_results
+
+            results = transform(results)
+
+            if 'mix_results' in results:
+                results.pop('mix_results')
+
+        return results
+
+    def update_skip_type_keys(self, skip_type_keys):
+        """Update skip_type_keys.
+
+        It is called by an external hook.
+
+        Args:
+            skip_type_keys (list[str], optional): Sequence of type
+                strings to be skipped in the pipeline.
+        """
+        assert all([
+            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+        ])
+        self._skip_type_keys = skip_type_keys
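
The wrapper's `__getitem__` defines the mixing contract: any transform that exposes a `get_indexes` method is handed the whole inner dataset, and the samples at the returned indexes are deep-copied into `results['mix_results']` before the transform itself runs. Since a `Mosaic` transform was not yet merged at this point (see the commented-out line in the test below), here is a hypothetical transform sketching the contract; the class name and blending logic are illustrative only:

import numpy as np

from mmseg.datasets.builder import PIPELINES


@PIPELINES.register_module()
class RandomMixUp:
    """Hypothetical mix transform showing the get_indexes contract."""

    def get_indexes(self, dataset):
        # Called first by MultiImageMixDataset; pick one random partner.
        # A non-sequence return value is wrapped into a list by the wrapper.
        return np.random.randint(0, len(dataset))

    def __call__(self, results):
        # The wrapper has already deep-copied the partner sample(s) here.
        partner = results['mix_results'][0]
        # A real mixup would blend images of equal shape; this sketch
        # assumes both images were already resized to a common scale.
        results['img'] = 0.5 * results['img'] + 0.5 * partner['img']
        return results

Note that `transform_type in self._skip_type_keys` is a plain membership test, so `skip_type_keys` also happens to work with a single string (as the test below does with `('RandomFlip')`, which is a string, not a tuple), though a list or tuple of type names is the intended usage.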

tests/test_data/test_dataset.py

Lines changed: 62 additions & 1 deletion

@@ -14,7 +14,8 @@
 from mmseg.core.evaluation import get_classes, get_palette
 from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                             ConcatDataset, CustomDataset, LoveDADataset,
-                            PascalVOCDataset, RepeatDataset, build_dataset)
+                            MultiImageMixDataset, PascalVOCDataset,
+                            RepeatDataset, build_dataset)


 def test_classes():
@@ -95,6 +96,66 @@ def test_dataset_wrapper():
     assert repeat_dataset[27] == 7
     assert len(repeat_dataset) == 10 * len(dataset_a)

+    img_scale = (60, 60)
+    pipeline = [
+        # dict(type='Mosaic', img_scale=img_scale, pad_val=255),
+        # need to merge mosaic
+        dict(type='RandomFlip', prob=0.5),
+        dict(type='Resize', img_scale=img_scale, keep_ratio=False),
+    ]
+
+    CustomDataset.load_annotations = MagicMock()
+    results = []
+    for _ in range(2):
+        height = np.random.randint(10, 30)
+        weight = np.random.randint(10, 30)
+        img = np.ones((height, weight, 3))
+        gt_semantic_seg = np.random.randint(5, size=(height, weight))
+        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))
+
+    classes = ['0', '1', '2', '3', '4']
+    palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
+    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
+    dataset_a = CustomDataset(
+        img_dir=MagicMock(),
+        pipeline=[],
+        test_mode=True,
+        classes=classes,
+        palette=palette)
+    len_a = 2
+    cat_ids_list_a = [
+        np.random.randint(0, 80, num).tolist()
+        for num in np.random.randint(1, 20, len_a)
+    ]
+    dataset_a.data_infos = MagicMock()
+    dataset_a.data_infos.__len__.return_value = len_a
+    dataset_a.get_cat_ids = MagicMock(
+        side_effect=lambda idx: cat_ids_list_a[idx])
+
+    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
+    assert len(multi_image_mix_dataset) == len(dataset_a)
+
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+
+    # test skip_type_keys
+    multi_image_mix_dataset = MultiImageMixDataset(
+        dataset_a, pipeline, skip_type_keys=('RandomFlip'))
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+        assert results_['img'].shape == (img_scale[0], img_scale[1], 3)
+
+    skip_type_keys = ('RandomFlip', 'Resize')
+    multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+        assert results_['img'].shape[:2] != img_scale
+
+    # test pipeline
+    with pytest.raises(TypeError):
+        pipeline = [['Resize']]
+        multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)


 def test_custom_dataset():
     img_norm_cfg = dict(

tests/test_data/test_dataset_builder.py

Lines changed: 7 additions & 2 deletions

@@ -6,8 +6,8 @@
 from torch.utils.data import (DistributedSampler, RandomSampler,
                               SequentialSampler)

-from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader,
-                            build_dataset)
+from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
+                            build_dataloader, build_dataset)


 @DATASETS.register_module()
@@ -48,6 +48,11 @@ def test_build_dataset():
     assert isinstance(dataset, ConcatDataset)
     assert len(dataset) == 10

+    cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[])
+    dataset = build_dataset(cfg)
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
     # with ann_dir, split
     cfg = dict(
         type='CustomDataset',

tools/browse_dataset.py

Lines changed: 22 additions & 8 deletions

@@ -5,7 +5,7 @@

 import mmcv
 import numpy as np
-from mmcv import Config
+from mmcv import Config, DictAction

 from mmseg.datasets.builder import build_dataset

@@ -42,6 +42,16 @@ def parse_args():
         type=float,
         default=0.5,
         help='the opacity of semantic map')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     args = parser.parse_args()
     return args


@@ -122,28 +132,32 @@ def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
     ]


-def retrieve_data_cfg(config_path, skip_type, show_origin=False):
+def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
     cfg = Config.fromfile(config_path)
+    if cfg_options is not None:
+        cfg.merge_from_dict(cfg_options)
     train_data_cfg = cfg.data.train
     if isinstance(train_data_cfg, list):
         for _data_cfg in train_data_cfg:
+            while 'dataset' in _data_cfg and _data_cfg[
+                    'type'] != 'MultiImageMixDataset':
+                _data_cfg = _data_cfg['dataset']
             if 'pipeline' in _data_cfg:
                 _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
-            elif 'dataset' in _data_cfg:
-                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
-                                   show_origin)
             else:
                 raise ValueError
-    elif 'dataset' in train_data_cfg:
-        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
     else:
+        while 'dataset' in train_data_cfg and train_data_cfg[
+                'type'] != 'MultiImageMixDataset':
+            train_data_cfg = train_data_cfg['dataset']
         _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
     return cfg


 def main():
     args = parse_args()
-    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
+    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
+                            args.show_origin)
     dataset = build_dataset(cfg.data.train)
     progress_bar = mmcv.ProgressBar(len(dataset))
     for item in dataset:
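
`DictAction` parses each `key=value` argument into a dot-keyed entry, and `Config.merge_from_dict` merges those entries into the nested config before the pipeline is rewritten, e.g. `python tools/browse_dataset.py ${CONFIG} --cfg-options data.train.img_dir=images/validation`. A minimal sketch of the merge semantics; the config dict and the overridden key are illustrative, not from this commit:

# Sketch of the merge performed inside retrieve_data_cfg.
from mmcv import Config

cfg = Config(
    dict(data=dict(train=dict(
        type='ADE20KDataset', img_dir='images/training'))))

# Equivalent of: --cfg-options data.train.img_dir=images/validation
cfg.merge_from_dict({'data.train.img_dir': 'images/validation'})
assert cfg.data.train.img_dir == 'images/validation'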
