@@ -79,17 +79,25 @@ def train_segmentor(model,
79
79
80
80
# prepare data loaders
81
81
dataset = dataset if isinstance (dataset , (list , tuple )) else [dataset ]
82
- data_loaders = [
83
- build_dataloader (
84
- ds ,
85
- cfg .data .samples_per_gpu ,
86
- cfg .data .workers_per_gpu ,
87
- # cfg.gpus will be ignored if distributed
88
- len (cfg .gpu_ids ),
89
- dist = distributed ,
90
- seed = cfg .seed ,
91
- drop_last = True ) for ds in dataset
92
- ]
82
+ # The default loader config
83
+ loader_cfg = dict (
84
+ # cfg.gpus will be ignored if distributed
85
+ num_gpus = len (cfg .gpu_ids ),
86
+ dist = distributed ,
87
+ seed = cfg .seed ,
88
+ drop_last = True )
89
+ # The overall dataloader settings
90
+ loader_cfg .update ({
91
+ k : v
92
+ for k , v in cfg .data .items () if k not in [
93
+ 'train' , 'val' , 'test' , 'train_dataloader' , 'val_dataloader' ,
94
+ 'test_dataloader'
95
+ ]
96
+ })
97
+
98
+ # The specific dataloader settings
99
+ train_loader_cfg = {** loader_cfg , ** cfg .data .get ('train_dataloader' , {})}
100
+ data_loaders = [build_dataloader (ds , ** train_loader_cfg ) for ds in dataset ]
93
101
94
102
# put model on gpus
95
103
if distributed :
@@ -142,12 +150,14 @@ def train_segmentor(model,
142
150
# register eval hooks
143
151
if validate :
144
152
val_dataset = build_dataset (cfg .data .val , dict (test_mode = True ))
145
- val_dataloader = build_dataloader (
146
- val_dataset ,
147
- samples_per_gpu = 1 ,
148
- workers_per_gpu = cfg .data .workers_per_gpu ,
149
- dist = distributed ,
150
- shuffle = False )
153
+ # The specific dataloader settings
154
+ val_loader_cfg = {
155
+ ** loader_cfg ,
156
+ 'samples_per_gpu' : 1 ,
157
+ 'shuffle' : False , # Not shuffle by default
158
+ ** cfg .data .get ('val_dataloader' , {}),
159
+ }
160
+ val_dataloader = build_dataloader (val_dataset , ** val_loader_cfg )
151
161
eval_cfg = cfg .get ('evaluation' , {})
152
162
eval_cfg ['by_epoch' ] = cfg .runner ['type' ] != 'IterBasedRunner'
153
163
eval_hook = DistEvalHook if distributed else EvalHook
0 commit comments