|
 from __future__ import division
 
-import logging
 import random
 from collections import OrderedDict
 
 import numpy as np
 import torch
 from mmcv.runner import Runner, DistSamplerSeedHook
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
-from mmdet import __version__
-from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
+from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
                         CocoDistEvalmAPHook)
 from mmdet.datasets import build_dataloader
 from mmdet.models import RPN
+from .env import get_root_logger
 
 
 def parse_losses(losses):
@@ -46,72 +45,79 @@ def batch_processor(model, data, train_mode): |
     return outputs
 
 
-def get_logger(log_level):
-    logging.basicConfig(
-        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
-    logger = logging.getLogger()
-    return logger
-
-
 def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
 
-def train_detector(model, dataset, cfg):
-    # save mmdet version in checkpoint as meta data
-    cfg.checkpoint_config.meta = dict(
-        mmdet_version=__version__, config=cfg.text)
-
-    logger = get_logger(cfg.log_level)
-
-    # set random seed if specified
-    if cfg.seed is not None:
-        logger.info('Set random seed to {}'.format(cfg.seed))
-        set_random_seed(cfg.seed)
+def train_detector(model,
+                   dataset,
+                   cfg,
+                   distributed=False,
+                   validate=False,
+                   logger=None):
+    if logger is None:
+        logger = get_root_logger(cfg.log_level)
 
-    # init distributed environment if necessary
-    if cfg.launcher == 'none':
-        dist = False
-        logger.info('Non-distributed training.')
+    # start training
+    if distributed:
+        _dist_train(model, dataset, cfg, validate=validate)
     else:
-        dist = True
-        init_dist(cfg.launcher, **cfg.dist_params)
-        if torch.distributed.get_rank() != 0:
-            logger.setLevel('ERROR')
-        logger.info('Distributed training.')
+        _non_dist_train(model, dataset, cfg, validate=validate)
 
+
+def _dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
     data_loaders = [
-        build_dataloader(dataset, cfg.data.imgs_per_gpu,
-                         cfg.data.workers_per_gpu, cfg.gpus, dist)
+        build_dataloader(
+            dataset,
+            cfg.data.imgs_per_gpu,
+            cfg.data.workers_per_gpu,
+            dist=True)
     ]
-
     # put model on gpus
-    if dist:
-        model = MMDistributedDataParallel(model.cuda())
-    else:
-        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
-
+    model = MMDistributedDataParallel(model.cuda())
     # build runner
     runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                     cfg.log_level)
-
     # register hooks
-    optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if dist else cfg.optimizer_config
+    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
-    if dist:
-        runner.register_hook(DistSamplerSeedHook())
-    # register eval hooks
-    if cfg.validate:
-        if isinstance(model.module, RPN):
-            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
-        elif cfg.data.val.type == 'CocoDataset':
-            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+    runner.register_hook(DistSamplerSeedHook())
+    # register eval hooks
+    if validate:
+        if isinstance(model.module, RPN):
+            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
+        elif cfg.data.val.type == 'CocoDataset':
+            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
+
+
+def _non_dist_train(model, dataset, cfg, validate=False):
+    # prepare data loaders
+    data_loaders = [
+        build_dataloader(
+            dataset,
+            cfg.data.imgs_per_gpu,
+            cfg.data.workers_per_gpu,
+            cfg.gpus,
+            dist=False)
+    ]
+    # put model on gpus
+    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
+    # build runner
+    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+                    cfg.log_level)
+    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config)
 
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
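
For context, this refactor changes the public signature of train_detector: whether training is distributed, whether to register validation hooks, and which logger to use are now decided by the caller instead of being read from cfg.launcher, cfg.validate, and a module-level logger. The sketch below shows how a training script might drive the new entry point. The init_dist, build_detector, and get_dataset helpers, the assumption that init_dist is re-exported from mmdet.apis (it was previously imported from mmdet.core in this file), and the example config path are illustrative assumptions, not part of this diff.

# Hypothetical caller for the refactored train_detector() API (not part of this
# diff). build_detector, get_dataset, init_dist and the config path are assumed
# helpers/values, used only for illustration.
from mmcv import Config

from mmdet.apis import get_root_logger, init_dist, train_detector
from mmdet.datasets import get_dataset
from mmdet.models import build_detector


def main():
    cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # example path
    logger = get_root_logger(cfg.log_level)

    # Distribution is now the caller's decision instead of train_detector's.
    distributed = cfg.launcher != 'none'
    if distributed:
        init_dist(cfg.launcher, **cfg.dist_params)

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    train_dataset = get_dataset(cfg.data.train)

    train_detector(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=True,
        logger=logger)


if __name__ == '__main__':
    main()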
|
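The new .env module referenced by "from .env import get_root_logger" is not shown in this hunk. A plausible sketch of such a helper, modeled on the deleted get_logger plus the deleted rank-based log-level handling, could look like the following; the rank check and the default level are assumptions, not the actual implementation.

# Sketch of what the new .env helper might provide (an assumption modeled on
# the removed get_logger() and the removed rank-based log-level handling, not
# the actual mmdet implementation).
import logging

import torch.distributed as dist


def get_root_logger(log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=log_level)
    # Only the rank-0 process keeps the requested level; other workers are
    # silenced so distributed runs do not log the same message once per GPU.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        logger.setLevel('ERROR')
    return logger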