Skip to content

Commit 6c4da82

Browse files
authored
Revert "mmocr-style cfg (open-mmlab#926)" (open-mmlab#927)
This reverts commit 7fdaebb.
1 parent 7fdaebb commit 6c4da82

File tree

2 files changed

+22
-50
lines changed

2 files changed

+22
-50
lines changed

mmpose/apis/train.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,33 +43,19 @@ def train_model(model,
4343

4444
# prepare data loaders
4545
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
46-
# step 1: give default values and override (if exist) from cfg.data
47-
loader_cfg = {
48-
**dict(
49-
seed=cfg.get('seed'),
50-
drop_last=False,
51-
dist=distributed,
52-
num_gpus=len(cfg.gpu_ids)),
53-
**({} if torch.__version__ != 'parrots' else dict(
54-
prefetch_num=2,
55-
pin_memory=False,
56-
)),
57-
**dict((k, cfg.data[k]) for k in [
58-
'samples_per_gpu',
59-
'workers_per_gpu',
60-
'shuffle',
61-
'seed',
62-
'drop_last',
63-
'prefetch_num',
64-
'pin_memory',
65-
'persistent_workers',
66-
] if k in cfg.data)
67-
}
68-
69-
# step 2: cfg.data.train_dataloader has highest priority
70-
train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
71-
72-
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
46+
dataloader_setting = dict(
47+
samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
48+
workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
49+
# cfg.gpus will be ignored if distributed
50+
num_gpus=len(cfg.gpu_ids),
51+
dist=distributed,
52+
seed=cfg.seed)
53+
dataloader_setting = dict(dataloader_setting,
54+
**cfg.data.get('train_dataloader', {}))
55+
56+
data_loaders = [
57+
build_dataloader(ds, **dataloader_setting) for ds in dataset
58+
]
7359

7460
# determine wether use adversarial training precess or not
7561
use_adverserial_train = cfg.get('use_adversarial_train', False)

tools/test.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -111,29 +111,15 @@ def main():
111111

112112
# build the dataloader
113113
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
114-
# step 1: give default values and override (if exist) from cfg.data
115-
loader_cfg = {
116-
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
117-
**({} if torch.__version__ != 'parrots' else dict(
118-
prefetch_num=2,
119-
pin_memory=False,
120-
)),
121-
**dict((k, cfg.data[k]) for k in [
122-
'seed',
123-
'prefetch_num',
124-
'pin_memory',
125-
'persistent_workers',
126-
] if k in cfg.data)
127-
}
128-
# step2: cfg.data.test_dataloader has higher priority
129-
test_loader_cfg = {
130-
**loader_cfg,
131-
**dict(shuffle=False, drop_last=False),
132-
**cfg.data.get('workers_per_gpu', 1),
133-
**cfg.data.get('test_dataloader', {}),
134-
**dict(samples_per_gpu=1)
135-
}
136-
data_loader = build_dataloader(dataset, **test_loader_cfg)
114+
dataloader_setting = dict(
115+
samples_per_gpu=1,
116+
workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
117+
dist=distributed,
118+
shuffle=False,
119+
drop_last=False)
120+
dataloader_setting = dict(dataloader_setting,
121+
**cfg.data.get('test_dataloader', {}))
122+
data_loader = build_dataloader(dataset, **dataloader_setting)
137123

138124
# build the model and load checkpoint
139125
model = build_posenet(cfg.model)

0 commit comments

Comments
 (0)