Skip to content

Commit dfd0a92

Browse files
committed
Support DefaultFormatBundle in image_demo
1 parent 3415ae9 commit dfd0a92

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

mmdet/apis/inference.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mmcv.runner import load_checkpoint
1010

1111
from mmdet.core import get_classes
12+
from mmdet.datasets import replace_ImageToTensor
1213
from mmdet.datasets.pipelines import Compose
1314
from mmdet.models import build_detector
1415

@@ -104,9 +105,13 @@ def inference_detector(model, img):
104105
# add information into dict
105106
data = dict(img_info=dict(filename=img), img_prefix=None)
106107
# build the data pipeline
108+
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
107109
test_pipeline = Compose(cfg.data.test.pipeline)
108110
data = test_pipeline(data)
109111
data = collate([data], samples_per_gpu=1)
112+
# just get the actual data from DataContainer
113+
data['img_metas'] = data['img_metas'][0].data
114+
data['img'] = data['img'][0].data
110115
if next(model.parameters()).is_cuda:
111116
# scatter to specified GPU
112117
data = scatter(data, [device])[0]
@@ -115,8 +120,6 @@ def inference_detector(model, img):
115120
assert not isinstance(
116121
m, RoIPool
117122
), 'CPU inference with RoIPool is not supported currently.'
118-
# just get the actual data from DataContainer
119-
data['img_metas'] = data['img_metas'][0].data
120123

121124
# forward the model
122125
with torch.no_grad():

0 commit comments

Comments
 (0)