Skip to content

Commit 54b54d8

Browse files
authored
Merge pull request open-mmlab#19 from hellock/dev
Update inference APIs
2 parents abc440f + 459d5eb commit 54b54d8

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

mmdet/apis/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from .env import init_dist, get_root_logger, set_random_seed
22
from .train import train_detector
3-
from .inference import inference_detector
3+
from .inference import inference_detector, show_result
44

55
__all__ = [
66
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
7-
'inference_detector'
7+
'inference_detector', 'show_result'
88
]

mmdet/apis/inference.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,29 @@ def _prepare_data(img, img_transform, cfg, device):
2323
return dict(img=[img], img_meta=[img_meta])
2424

2525

26-
def inference_detector(model, imgs, cfg, device='cuda:0'):
26+
def _inference_single(model, img, img_transform, cfg, device):
27+
img = mmcv.imread(img)
28+
data = _prepare_data(img, img_transform, cfg, device)
29+
with torch.no_grad():
30+
result = model(return_loss=False, rescale=True, **data)
31+
return result
32+
33+
34+
def _inference_generator(model, imgs, img_transform, cfg, device):
35+
for img in imgs:
36+
yield _inference_single(model, img, img_transform, cfg, device)
2737

28-
imgs = imgs if isinstance(imgs, list) else [imgs]
38+
39+
def inference_detector(model, imgs, cfg, device='cuda:0'):
2940
img_transform = ImageTransform(
3041
size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
3142
model = model.to(device)
3243
model.eval()
33-
for img in imgs:
34-
img = mmcv.imread(img)
35-
data = _prepare_data(img, img_transform, cfg, device)
36-
with torch.no_grad():
37-
result = model(return_loss=False, rescale=True, **data)
38-
yield result
44+
45+
if not isinstance(imgs, list):
46+
return _inference_single(model, imgs, img_transform, cfg, device)
47+
else:
48+
return _inference_generator(model, imgs, img_transform, cfg, device)
3949

4050

4151
def show_result(img, result, dataset='coco', score_thr=0.3):
@@ -46,6 +56,7 @@ def show_result(img, result, dataset='coco', score_thr=0.3):
4656
]
4757
labels = np.concatenate(labels)
4858
bboxes = np.vstack(result)
59+
img = mmcv.imread(img)
4960
mmcv.imshow_det_bboxes(
5061
img.copy(),
5162
bboxes,

mmdet/models/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import nn
33

44
from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
5-
mask_heads, detectors)
5+
mask_heads)
66

77
__all__ = [
88
'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
@@ -48,4 +48,5 @@ def build_mask_head(cfg):
4848

4949

5050
def build_detector(cfg, train_cfg=None, test_cfg=None):
51+
from . import detectors
5152
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))

mmdet/models/rpn_heads/rpn_head.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __init__(self,
4848
self.anchor_scales = anchor_scales
4949
self.anchor_ratios = anchor_ratios
5050
self.anchor_strides = anchor_strides
51-
self.anchor_base_sizes = anchor_strides.copy(
52-
) if anchor_base_sizes is None else anchor_base_sizes
51+
self.anchor_base_sizes = list(
52+
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
5353
self.target_means = target_means
5454
self.target_stds = target_stds
5555
self.use_sigmoid_cls = use_sigmoid_cls

0 commit comments

Comments
 (0)