@@ -79,17 +79,25 @@ def train_segmentor(model,
79 79
8080 # prepare data loaders
8181 dataset = dataset if isinstance (dataset , (list , tuple )) else [dataset ]
82- data_loaders = [
83- build_dataloader (
84- ds ,
85- cfg .data .samples_per_gpu ,
86- cfg .data .workers_per_gpu ,
87- # cfg.gpus will be ignored if distributed
88- len (cfg .gpu_ids ),
89- dist = distributed ,
90- seed = cfg .seed ,
91- drop_last = True ) for ds in dataset
92- ]
82+ # The default loader config
83+ loader_cfg = dict (
84+ # cfg.gpus will be ignored if distributed
85+ num_gpus = len (cfg .gpu_ids ),
86+ dist = distributed ,
87+ seed = cfg .seed ,
88+ drop_last = True )
89+ # The overall dataloader settings
90+ loader_cfg .update ({
91+ k : v
92+ for k , v in cfg .data .items () if k not in [
93+ 'train' , 'val' , 'test' , 'train_dataloader' , 'val_dataloader' ,
94+ 'test_dataloader'
95+ ]
96+ })
97+
98+ # The specific dataloader settings
99+ train_loader_cfg = {** loader_cfg , ** cfg .data .get ('train_dataloader' , {})}
100+ data_loaders = [build_dataloader (ds , ** train_loader_cfg ) for ds in dataset ]
93 101
94102 # put model on gpus
95103 if distributed :
@@ -142,12 +150,14 @@ def train_segmentor(model,
142150 # register eval hooks
143151 if validate :
144152 val_dataset = build_dataset (cfg .data .val , dict (test_mode = True ))
145- val_dataloader = build_dataloader (
146- val_dataset ,
147- samples_per_gpu = 1 ,
148- workers_per_gpu = cfg .data .workers_per_gpu ,
149- dist = distributed ,
150- shuffle = False )
153+ # The specific dataloader settings
154+ val_loader_cfg = {
155+ ** loader_cfg ,
156+ 'samples_per_gpu' : 1 ,
157+ 'shuffle' : False , # Not shuffle by default
158+ ** cfg .data .get ('val_dataloader' , {}),
159+ }
160+ val_dataloader = build_dataloader (val_dataset , ** val_loader_cfg )
151161 eval_cfg = cfg .get ('evaluation' , {})
152162 eval_cfg ['by_epoch' ] = cfg .runner ['type' ] != 'IterBasedRunner'
153163 eval_hook = DistEvalHook if distributed else EvalHook
0 commit comments