@@ -43,33 +43,19 @@ def train_model(model,
4343
4444 # prepare data loaders
4545 dataset = dataset if isinstance (dataset , (list , tuple )) else [dataset ]
46- # step 1: give default values and override (if exist) from cfg.data
47- loader_cfg = {
48- ** dict (
49- seed = cfg .get ('seed' ),
50- drop_last = False ,
51- dist = distributed ,
52- num_gpus = len (cfg .gpu_ids )),
53- ** ({} if torch .__version__ != 'parrots' else dict (
54- prefetch_num = 2 ,
55- pin_memory = False ,
56- )),
57- ** dict ((k , cfg .data [k ]) for k in [
58- 'samples_per_gpu' ,
59- 'workers_per_gpu' ,
60- 'shuffle' ,
61- 'seed' ,
62- 'drop_last' ,
63- 'prefetch_num' ,
64- 'pin_memory' ,
65- 'persistent_workers' ,
66- ] if k in cfg .data )
67- }
68-
69- # step 2: cfg.data.train_dataloader has highest priority
70- train_loader_cfg = dict (loader_cfg , ** cfg .data .get ('train_dataloader' , {}))
71-
72- data_loaders = [build_dataloader (ds , ** train_loader_cfg ) for ds in dataset ]
46+ dataloader_setting = dict (
47+ samples_per_gpu = cfg .data .get ('samples_per_gpu' , {}),
48+ workers_per_gpu = cfg .data .get ('workers_per_gpu' , {}),
49+ # cfg.gpus will be ignored if distributed
50+ num_gpus = len (cfg .gpu_ids ),
51+ dist = distributed ,
52+ seed = cfg .seed )
53+ dataloader_setting = dict (dataloader_setting ,
54+ ** cfg .data .get ('train_dataloader' , {}))
55+
56+ data_loaders = [
57+ build_dataloader (ds , ** dataloader_setting ) for ds in dataset
58+ ]
7359
7460 # determine whether to use adversarial training process or not
7561 use_adverserial_train = cfg .get ('use_adversarial_train' , False )
0 commit comments