Commit df2aab9

Merge pull request open-mmlab#16 from myownskyW7/dev
add high level api
2 parents 3c51dcc + 724abbc commit df2aab9

8 files changed (+275, -147 lines)

mmdet/api/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from .env import init_dist, get_root_logger, set_random_seed
+from .train import train_detector
+from .inference import inference_detector
+
+__all__ = [
+    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
+    'inference_detector'
+]

mmdet/api/env.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from mmcv.runner import get_dist_info
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
+
+
+def _init_dist_slurm(backend, **kwargs):
+    raise NotImplementedError
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_root_logger(log_level=logging.INFO):
+    logger = logging.getLogger()
+    if not logger.hasHandlers():
+        logging.basicConfig(
+            format='%(asctime)s - %(levelname)s - %(message)s',
+            level=log_level)
+    rank, _ = get_dist_info()
+    if rank != 0:
+        logger.setLevel('ERROR')
+    return logger
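
The three helpers above are one-time setup calls. A minimal usage sketch, assuming a script launched with python -m torch.distributed.launch (which sets the RANK variable that _init_dist_pytorch reads); the backend and seed values are illustrative, not part of this diff:

import logging

from mmdet.api import get_root_logger, init_dist, set_random_seed

init_dist('pytorch', backend='nccl')  # requires RANK in the environment
set_random_seed(42)  # hypothetical seed

# rank 0 logs at the requested level; all other ranks are raised to ERROR
logger = get_root_logger(log_level=logging.INFO)
logger.info('distributed environment ready')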

mmdet/api/inference.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import mmcv
+import numpy as np
+import torch
+
+from mmdet.datasets import to_tensor
+from mmdet.datasets.transforms import ImageTransform
+from mmdet.core import get_classes
+
+
+def _prepare_data(img, img_transform, cfg, device):
+    ori_shape = img.shape
+    img, img_shape, pad_shape, scale_factor = img_transform(
+        img, scale=cfg.data.test.img_scale)
+    img = to_tensor(img).to(device).unsqueeze(0)
+    img_meta = [
+        dict(
+            ori_shape=ori_shape,
+            img_shape=img_shape,
+            pad_shape=pad_shape,
+            scale_factor=scale_factor,
+            flip=False)
+    ]
+    return dict(img=[img], img_meta=[img_meta])
+
+
+def inference_detector(model, imgs, cfg, device='cuda:0'):
+
+    imgs = imgs if isinstance(imgs, list) else [imgs]
+    img_transform = ImageTransform(
+        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
+    model = model.to(device)
+    model.eval()
+    for img in imgs:
+        img = mmcv.imread(img)
+        data = _prepare_data(img, img_transform, cfg, device)
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+        yield result
+
+
+def show_result(img, result, dataset='coco', score_thr=0.3):
+    class_names = get_classes(dataset)
+    labels = [
+        np.full(bbox.shape[0], i, dtype=np.int32)
+        for i, bbox in enumerate(result)
+    ]
+    labels = np.concatenate(labels)
+    bboxes = np.vstack(result)
+    mmcv.imshow_det_bboxes(
+        img.copy(),
+        bboxes,
+        labels,
+        class_names=class_names,
+        score_thr=score_thr)
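
Note that inference_detector is a generator: it yields one result per input image, so even a single image goes through next(). A minimal usage sketch; the config, checkpoint, and image paths are hypothetical, and build_detector is assumed to be exported by mmdet.models:

import mmcv
from mmcv.runner import load_checkpoint

from mmdet.api import inference_detector
from mmdet.api.inference import show_result
from mmdet.models import build_detector

cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # hypothetical
cfg.model.pretrained = None

model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
load_checkpoint(model, 'checkpoints/faster_rcnn_r50.pth')  # hypothetical path

# single image: take the first (and only) yielded result
result = next(inference_detector(model, 'demo.jpg', cfg, device='cuda:0'))
show_result(mmcv.imread('demo.jpg'), result)

# several images: iterate over the generator
for i, res in enumerate(inference_detector(model, ['a.jpg', 'b.jpg'], cfg)):
    print('image {}: {} detections'.format(i, sum(len(b) for b in res)))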

mmdet/api/train.py

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+from __future__ import division
+
+from collections import OrderedDict
+
+import torch
+from mmcv.runner import Runner, DistSamplerSeedHook
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+
+from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
+                        CocoDistEvalmAPHook)
+from mmdet.datasets import build_dataloader
+from mmdet.models import RPN
+from .env import get_root_logger
+
+
+def parse_losses(losses):
+    log_vars = OrderedDict()
+    for loss_name, loss_value in losses.items():
+        if isinstance(loss_value, torch.Tensor):
+            log_vars[loss_name] = loss_value.mean()
+        elif isinstance(loss_value, list):
+            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+        else:
+            raise TypeError(
+                '{} is not a tensor or list of tensors'.format(loss_name))
+
+    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
+
+    log_vars['loss'] = loss
+    for name in log_vars:
+        log_vars[name] = log_vars[name].item()
+
+    return loss, log_vars
+
+
+def batch_processor(model, data, train_mode):
+    losses = model(**data)
+    loss, log_vars = parse_losses(losses)
+
+    outputs = dict(
+        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
+
+    return outputs
+
+
+def train_detector(model,
+                   dataset,
+                   cfg,
+                   distributed=False,
+                   validate=False,
+                   logger=None):
+    if logger is None:
+        logger = get_root_logger(cfg.log_level)
+
+    # start training
+    if distributed:
+        _dist_train(model, dataset, cfg, validate=validate)
+    else:
+        _non_dist_train(model, dataset, cfg, validate=validate)
+
+
+def _dist_train(model, dataset, cfg, validate=False):
+    # prepare data loaders
+    data_loaders = [
+        build_dataloader(
+            dataset,
+            cfg.data.imgs_per_gpu,
+            cfg.data.workers_per_gpu,
+            dist=True)
+    ]
+    # put model on gpus
+    model = MMDistributedDataParallel(model.cuda())
+    # build runner
+    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+                    cfg.log_level)
+    # register hooks
+    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
+    runner.register_training_hooks(cfg.lr_config, optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config)
+    runner.register_hook(DistSamplerSeedHook())
+    # register eval hooks
+    if validate:
+        if isinstance(model.module, RPN):
+            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
+        elif cfg.data.val.type == 'CocoDataset':
+            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
+
+
+def _non_dist_train(model, dataset, cfg, validate=False):
+    # prepare data loaders
+    data_loaders = [
+        build_dataloader(
+            dataset,
+            cfg.data.imgs_per_gpu,
+            cfg.data.workers_per_gpu,
+            cfg.gpus,
+            dist=False)
+    ]
+    # put model on gpus
+    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
+    # build runner
+    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+                    cfg.log_level)
+    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config)
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
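
A minimal sketch of an entry point built on train_detector. The config path is hypothetical, obj_from_dict is assumed to be the config-to-object helper from mmcv of this era, and build_detector is assumed exported by mmdet.models; any dataset accepted by build_dataloader would work:

import mmcv
from mmcv.runner import obj_from_dict

from mmdet import datasets
from mmdet.api import get_root_logger, init_dist, train_detector
from mmdet.models import build_detector

cfg = mmcv.Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')  # hypothetical

distributed = False  # set True only after calling init_dist(...)
logger = get_root_logger(cfg.log_level)

model = build_detector(cfg.model, train_cfg=cfg.train_cfg,
                       test_cfg=cfg.test_cfg)
train_dataset = obj_from_dict(cfg.data.train, datasets)  # assumed builder

train_detector(model, train_dataset, cfg, distributed=distributed,
               validate=False, logger=logger)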

mmdet/core/utils/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
-from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
+from .dist_utils import allreduce_grads, DistOptimizerHook
 from .misc import tensor2imgs, unmap, multi_apply
 
 __all__ = [
-    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
-    'unmap', 'multi_apply'
+    'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap',
+    'multi_apply'
 ]

mmdet/core/utils/dist_utils.py

Lines changed: 0 additions & 32 deletions

@@ -1,43 +1,11 @@
-import os
 from collections import OrderedDict
 
-import torch
-import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                           _take_tensors)
 from mmcv.runner import OptimizerHook
 
 
-def init_dist(launcher, backend='nccl', **kwargs):
-    if mp.get_start_method(allow_none=True) is None:
-        mp.set_start_method('spawn')
-    if launcher == 'pytorch':
-        _init_dist_pytorch(backend, **kwargs)
-    elif launcher == 'mpi':
-        _init_dist_mpi(backend, **kwargs)
-    elif launcher == 'slurm':
-        _init_dist_slurm(backend, **kwargs)
-    else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
-
-
-def _init_dist_pytorch(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_mpi(backend, **kwargs):
-    raise NotImplementedError
-
-
-def _init_dist_slurm(backend, **kwargs):
-    raise NotImplementedError
-
-
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
         bucket_size_bytes = bucket_size_mb * 1024 * 1024

mmdet/datasets/loader/build_loader.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 def build_dataloader(dataset,
                      imgs_per_gpu,
                      workers_per_gpu,
-                     num_gpus,
+                     num_gpus=1,
                      dist=True,
                      **kwargs):
     if dist:
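
With num_gpus now defaulting to 1, callers can omit it, which is what the new _dist_train above relies on. A sketch of both call styles, assuming dataset and cfg as in the training sketch earlier:

from mmdet.datasets import build_dataloader

# distributed: the sampler shards the dataset, so num_gpus can be omitted
loader = build_dataloader(dataset, cfg.data.imgs_per_gpu,
                          cfg.data.workers_per_gpu, dist=True)

# non-distributed: the batch is spread across cfg.gpus devices
loader = build_dataloader(dataset, cfg.data.imgs_per_gpu,
                          cfg.data.workers_per_gpu, cfg.gpus, dist=False)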
