
Commit c76a95e

reorganize the training api
1 parent 2a43cc7 commit c76a95e

7 files changed (+161, -97 lines)

mmdet/api/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,8 @@
+from .env import init_dist, get_root_logger, set_random_seed
 from .train import train_detector
 from .inference import inference_detector
 
-__all__ = ['train_detector', 'inference_detector']
+__all__ = [
+    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
+    'inference_detector'
+]

mmdet/api/env.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from mmcv.runner import get_dist_info
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
+
+
+def _init_dist_slurm(backend, **kwargs):
+    raise NotImplementedError
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_root_logger(log_level=logging.INFO):
+    logger = logging.getLogger()
+    if not logger.hasHandlers():
+        logging.basicConfig(
+            format='%(asctime)s - %(levelname)s - %(message)s',
+            level=log_level)
+    rank, _ = get_dist_info()
+    if rank != 0:
+        logger.setLevel('ERROR')
+    return logger
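
As a quick orientation (not part of the diff): a minimal sketch of the two relocated helpers that can be exercised in a plain, non-distributed process. init_dist additionally needs launcher-provided environment variables such as RANK.

# Sketch only, not part of this commit: exercising the relocated helpers
# in a plain (non-distributed) process.
import logging

from mmdet.api import get_root_logger, set_random_seed

set_random_seed(0)                      # seeds python, numpy and torch RNGs
logger = get_root_logger(logging.INFO)  # non-zero ranks are set to ERROR, so only rank 0 logs INFO
logger.info('seeded and ready')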

mmdet/api/train.py

Lines changed: 55 additions & 49 deletions
@@ -1,6 +1,5 @@
 from __future__ import division
 
-import logging
 import random
 from collections import OrderedDict
 
@@ -9,11 +8,11 @@
 from mmcv.runner import Runner, DistSamplerSeedHook
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
-from mmdet import __version__
-from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
+from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
                         CocoDistEvalmAPHook)
 from mmdet.datasets import build_dataloader
 from mmdet.models import RPN
+from .env import get_root_logger
 
 
 def parse_losses(losses):
@@ -46,72 +45,79 @@ def batch_processor(model, data, train_mode):
     return outputs
 
 
-def get_logger(log_level):
-    logging.basicConfig(
-        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
-    logger = logging.getLogger()
-    return logger
-
-
 def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
 
-def train_detector(model, dataset, cfg):
-    # save mmdet version in checkpoint as meta data
-    cfg.checkpoint_config.meta = dict(
-        mmdet_version=__version__, config=cfg.text)
-
-    logger = get_logger(cfg.log_level)
-
-    # set random seed if specified
-    if cfg.seed is not None:
-        logger.info('Set random seed to {}'.format(cfg.seed))
-        set_random_seed(cfg.seed)
+def train_detector(model,
+                   dataset,
+                   cfg,
+                   distributed=False,
+                   validate=False,
+                   logger=None):
+    if logger is None:
+        logger = get_root_logger(cfg.log_level)
 
-    # init distributed environment if necessary
-    if cfg.launcher == 'none':
-        dist = False
-        logger.info('Non-distributed training.')
+    # start training
+    if distributed:
+        _dist_train(model, dataset, cfg, validate=validate)
     else:
-        dist = True
-        init_dist(cfg.launcher, **cfg.dist_params)
-        if torch.distributed.get_rank() != 0:
-            logger.setLevel('ERROR')
-        logger.info('Distributed training.')
+        _non_dist_train(model, dataset, cfg, validate=validate)
 
+
+def _dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
     data_loaders = [
-        build_dataloader(dataset, cfg.data.imgs_per_gpu,
-                         cfg.data.workers_per_gpu, cfg.gpus, dist)
+        build_dataloader(
+            dataset,
+            cfg.data.imgs_per_gpu,
+            cfg.data.workers_per_gpu,
+            dist=True)
     ]
-
     # put model on gpus
-    if dist:
-        model = MMDistributedDataParallel(model.cuda())
-    else:
-        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
-
+    model = MMDistributedDataParallel(model.cuda())
     # build runner
     runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                     cfg.log_level)
-
     # register hooks
-    optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if dist else cfg.optimizer_config
+    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
-    if dist:
-        runner.register_hook(DistSamplerSeedHook())
-    # register eval hooks
-    if cfg.validate:
-        if isinstance(model.module, RPN):
-            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
-        elif cfg.data.val.type == 'CocoDataset':
-            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+    runner.register_hook(DistSamplerSeedHook())
+    # register eval hooks
+    if validate:
+        if isinstance(model.module, RPN):
+            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
+        elif cfg.data.val.type == 'CocoDataset':
+            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
+
+
+def _non_dist_train(model, dataset, cfg, validate=False):
+    # prepare data loaders
+    data_loaders = [
+        build_dataloader(
+            dataset,
+            cfg.data.imgs_per_gpu,
+            cfg.data.workers_per_gpu,
+            cfg.gpus,
+            dist=False)
+    ]
+    # put model on gpus
+    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
+    # build runner
+    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+                    cfg.log_level)
+    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config)
 
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
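
With this split, distributed and validate become explicit arguments of train_detector instead of config fields. A minimal caller sketch (the config path is a placeholder, not part of this commit):

# Hypothetical caller of the refactored API; the config path is a placeholder.
from mmcv import Config
from mmcv.runner import obj_from_dict

from mmdet import datasets
from mmdet.api import get_root_logger, train_detector
from mmdet.models import build_detector

cfg = Config.fromfile('configs/example_config.py')  # placeholder path
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = obj_from_dict(cfg.data.train, datasets)
train_detector(
    model,
    train_dataset,
    cfg,
    distributed=False,  # True dispatches to _dist_train and expects init_dist to have run
    validate=False,
    logger=get_root_logger(cfg.log_level))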

mmdet/core/utils/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
-from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
+from .dist_utils import allreduce_grads, DistOptimizerHook
 from .misc import tensor2imgs, unmap, multi_apply
 
 __all__ = [
-    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
-    'unmap', 'multi_apply'
+    'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap',
+    'multi_apply'
 ]

mmdet/core/utils/dist_utils.py

Lines changed: 0 additions & 32 deletions
@@ -1,43 +1,11 @@
-import os
 from collections import OrderedDict
 
-import torch
-import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                           _take_tensors)
 from mmcv.runner import OptimizerHook
 
 
-def init_dist(launcher, backend='nccl', **kwargs):
-    if mp.get_start_method(allow_none=True) is None:
-        mp.set_start_method('spawn')
-    if launcher == 'pytorch':
-        _init_dist_pytorch(backend, **kwargs)
-    elif launcher == 'mpi':
-        _init_dist_mpi(backend, **kwargs)
-    elif launcher == 'slurm':
-        _init_dist_slurm(backend, **kwargs)
-    else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
-
-
-def _init_dist_pytorch(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_mpi(backend, **kwargs):
-    raise NotImplementedError
-
-
-def _init_dist_slurm(backend, **kwargs):
-    raise NotImplementedError
-
-
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
         bucket_size_bytes = bucket_size_mb * 1024 * 1024

mmdet/datasets/loader/build_loader.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 def build_dataloader(dataset,
                      imgs_per_gpu,
                      workers_per_gpu,
-                     num_gpus,
+                     num_gpus=1,
                      dist=True,
                      **kwargs):
     if dist:
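
With num_gpus defaulting to 1, the distributed branch in mmdet/api/train.py can drop the argument while the non-distributed branch keeps passing cfg.gpus. A sketch of the two call styles (dataset and cfg are assumed to exist already):

# Sketch only: both call styles after this change; `dataset` and `cfg`
# are assumed to be an existing dataset object and a loaded mmcv Config.
from mmdet.datasets import build_dataloader

# distributed: one process per GPU, so num_gpus can be left at its default
dist_loader = build_dataloader(
    dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)

# non-distributed: still pass the number of GPUs explicitly
local_loader = build_dataloader(
    dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, cfg.gpus,
    dist=False)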

tools/train.py

Lines changed: 40 additions & 11 deletions
@@ -4,8 +4,9 @@
 from mmcv import Config
 from mmcv.runner import obj_from_dict
 
-from mmdet import datasets
-from mmdet.api import train_detector
+from mmdet import datasets, __version__
+from mmdet.api import (train_detector, init_dist, get_root_logger,
+                       set_random_seed)
 from mmdet.models import build_detector
 
 
@@ -16,10 +17,14 @@ def parse_args():
     parser.add_argument(
         '--validate',
         action='store_true',
-        help='whether to add a validate phase')
+        help='whether to evaluate the checkpoint during training')
     parser.add_argument(
-        '--gpus', type=int, default=1, help='number of gpus to use')
-    parser.add_argument('--seed', type=int, help='random seed')
+        '--gpus',
+        type=int,
+        default=1,
+        help='number of gpus to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=None, help='random seed')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -33,19 +38,43 @@ def parse_args():
 
 def main():
     args = parse_args()
+
     cfg = Config.fromfile(args.config)
+    # update configs according to CLI args
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
-    cfg.validate = args.validate
     cfg.gpus = args.gpus
-    cfg.seed = args.seed
-    cfg.launcher = args.launcher
-    cfg.local_rank = args.local_rank
-    # build model
+    if cfg.checkpoint_config is not None:
+        # save mmdet version in checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=__version__, config=cfg.text)
+
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+
+    # init logger before other steps
+    logger = get_root_logger(cfg.log_level)
+    logger.info('Distributed training: {}'.format(distributed))
+
+    # set random seeds
+    if args.seed is not None:
+        logger.info('Set random seed to {}'.format(args.seed))
+        set_random_seed(args.seed)
+
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     train_dataset = obj_from_dict(cfg.data.train, datasets)
-    train_detector(model, train_dataset, cfg)
+    train_detector(
+        model,
+        train_dataset,
+        cfg,
+        distributed=distributed,
+        validate=args.validate,
+        logger=logger)
 
 
 if __name__ == '__main__':
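
Note that _init_dist_pytorch reads RANK from the environment and init_process_group falls back to the default env:// rendezvous, so the 'pytorch' launcher relies on the launching tool to export the usual variables. A single-process smoke-test sketch with placeholder values (not part of this commit):

# Single-process smoke test of the 'pytorch' launcher path; the values below
# are placeholders that a multi-process launcher would normally provide.
import os

os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

from mmdet.api import init_dist

init_dist('pytorch', backend='nccl')  # picks the CUDA device as RANK % num_gpus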
