Skip to content

Commit 3f79707

Browse files
authored
[Enhance] Add extra dataloader settings in configs (open-mmlab#1435)
* [Enhance] Add extra dataloader settings in configs * val default samples * val default samples * del unuse * del unused
1 parent 91b1bcb commit 3f79707

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

mmseg/apis/train.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,25 @@ def train_segmentor(model,
7979

8080
# prepare data loaders
8181
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
82-
data_loaders = [
83-
build_dataloader(
84-
ds,
85-
cfg.data.samples_per_gpu,
86-
cfg.data.workers_per_gpu,
87-
# cfg.gpus will be ignored if distributed
88-
len(cfg.gpu_ids),
89-
dist=distributed,
90-
seed=cfg.seed,
91-
drop_last=True) for ds in dataset
92-
]
82+
# The default loader config
83+
loader_cfg = dict(
84+
# cfg.gpus will be ignored if distributed
85+
num_gpus=len(cfg.gpu_ids),
86+
dist=distributed,
87+
seed=cfg.seed,
88+
drop_last=True)
89+
# The overall dataloader settings
90+
loader_cfg.update({
91+
k: v
92+
for k, v in cfg.data.items() if k not in [
93+
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
94+
'test_dataloader'
95+
]
96+
})
97+
98+
# The specific dataloader settings
99+
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
100+
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
93101

94102
# put model on gpus
95103
if distributed:
@@ -142,12 +150,14 @@ def train_segmentor(model,
142150
# register eval hooks
143151
if validate:
144152
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
145-
val_dataloader = build_dataloader(
146-
val_dataset,
147-
samples_per_gpu=1,
148-
workers_per_gpu=cfg.data.workers_per_gpu,
149-
dist=distributed,
150-
shuffle=False)
153+
# The specific dataloader settings
154+
val_loader_cfg = {
155+
**loader_cfg,
156+
'samples_per_gpu': 1,
157+
'shuffle': False, # Not shuffle by default
158+
**cfg.data.get('val_dataloader', {}),
159+
}
160+
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
151161
eval_cfg = cfg.get('evaluation', {})
152162
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
153163
eval_hook = DistEvalHook if distributed else EvalHook

tools/test.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,28 @@ def main():
191191
# build the dataloader
192192
# TODO: support multiple images per gpu (only minor changes are needed)
193193
dataset = build_dataset(cfg.data.test)
194-
data_loader = build_dataloader(
195-
dataset,
196-
samples_per_gpu=1,
197-
workers_per_gpu=cfg.data.workers_per_gpu,
194+
# The default loader config
195+
loader_cfg = dict(
196+
# cfg.gpus will be ignored if distributed
197+
num_gpus=len(cfg.gpu_ids),
198198
dist=distributed,
199199
shuffle=False)
200+
# The overall dataloader settings
201+
loader_cfg.update({
202+
k: v
203+
for k, v in cfg.data.items() if k not in [
204+
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
205+
'test_dataloader'
206+
]
207+
})
208+
test_loader_cfg = {
209+
**loader_cfg,
210+
'samples_per_gpu': 1,
211+
'shuffle': False, # Not shuffle by default
212+
**cfg.data.get('test_dataloader', {})
213+
}
214+
# build the dataloader
215+
data_loader = build_dataloader(dataset, **test_loader_cfg)
200216

201217
# build the model and load checkpoint
202218
cfg.model.train_cfg = None

0 commit comments

Comments
 (0)