from __future__ import division

from collections import OrderedDict

import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
                        CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger


def parse_losses(losses):
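    """Aggregate a model's raw loss dict into a single scalar plus log values.

    ``losses`` maps loss names to tensors or to lists of tensors (e.g. one
    tensor per feature level). Every entry whose key contains ``'loss'`` is
    summed into the returned ``loss``; other entries (such as accuracies) are
    only logged. ``log_vars`` maps each entry, plus the total ``'loss'``, to a
    plain Python number.

    Example with hypothetical values: ``dict(loss_cls=tensor(0.8),
    loss_bbox=[tensor(0.2), tensor(0.3)], acc=tensor(91.0))`` yields a
    ``loss`` of about 1.3 and ``log_vars`` of ``{'loss_cls': 0.8,
    'loss_bbox': 0.5, 'acc': 91.0, 'loss': 1.3}`` (up to float rounding).
    """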
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()

    return loss, log_vars


def batch_processor(model, data, train_mode):
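    """Run one forward pass and package the result for ``mmcv.Runner``.

    The returned dict follows mmcv's ``batch_processor`` contract: ``loss`` is
    the tensor to backpropagate, ``log_vars`` holds scalar values for logging,
    and ``num_samples`` is used to average those values in the log buffer.
    ``train_mode`` is part of the interface but is not used here.
    """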
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
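    """Train a detector, dispatching to distributed or non-distributed mode.

    ``cfg`` is the full training config (data, optimizer, hooks, workflow).
    When ``validate`` is True, COCO evaluation hooks are registered, which is
    currently only done in the distributed code path. If ``logger`` is None,
    a root logger is created from ``cfg.log_level``.
    """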
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)


def _dist_train(model, dataset, cfg, validate=False):
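    """Distributed training: wrap the model in ``MMDistributedDataParallel``,
    register distributed optimizer/sampler hooks (and optionally COCO
    evaluation hooks), then run the mmcv ``Runner``."""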
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


def _non_dist_train(model, dataset, cfg, validate=False):
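    """Single-machine training: wrap the model in ``MMDataParallel`` over
    ``cfg.gpus`` GPUs and run the mmcv ``Runner``. Note that ``validate`` is
    accepted but no evaluation hooks are registered in this code path."""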
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
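
# Usage sketch (illustrative only, not part of this module): a training
# script would typically build the model and dataset from a config file and
# then call ``train_detector``. The config path below is hypothetical, and the
# exact dataset-construction helper differs between mmdetection versions; the
# dataset object only needs to be accepted by ``build_dataloader``.
#
#     from mmcv import Config
#     from mmdet.models import build_detector
#
#     cfg = Config.fromfile('configs/some_detector_config.py')  # hypothetical
#     model = build_detector(
#         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
#     dataset = ...  # construct the training dataset from cfg.data.train
#     train_detector(model, dataset, cfg, distributed=False, validate=False)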