99from mmcv .runner import load_checkpoint
1010
1111from mmdet .core import get_classes
12+ from mmdet .datasets import replace_ImageToTensor
1213from mmdet .datasets .pipelines import Compose
1314from mmdet .models import build_detector
1415
@@ -104,9 +105,13 @@ def inference_detector(model, img):
104105 # add information into dict
105106 data = dict (img_info = dict (filename = img ), img_prefix = None )
106107 # build the data pipeline
108+ cfg .data .test .pipeline = replace_ImageToTensor (cfg .data .test .pipeline )
107109 test_pipeline = Compose (cfg .data .test .pipeline )
108110 data = test_pipeline (data )
109111 data = collate ([data ], samples_per_gpu = 1 )
112+ # just get the actual data from DataContainer
113+ data ['img_metas' ] = data ['img_metas' ][0 ].data
114+ data ['img' ] = data ['img' ][0 ].data
110115 if next (model .parameters ()).is_cuda :
111116 # scatter to specified GPU
112117 data = scatter (data , [device ])[0 ]
@@ -115,8 +120,6 @@ def inference_detector(model, img):
115120 assert not isinstance (
116121 m , RoIPool
117122 ), 'CPU inference with RoIPool is not supported currently.'
118- # just get the actual data from DataContainer
119- data ['img_metas' ] = data ['img_metas' ][0 ].data
120123
121124 # forward the model
122125 with torch .no_grad ():
0 commit comments