@@ -123,7 +123,7 @@ def __init__(self, engine_file_path):
123123 self .bindings = bindings
124124 self .batch_size = engine .max_batch_size
125125
126- def infer (self , image_path_batch ):
126+ def infer (self , raw_image_generator ):
127127 threading .Thread .__init__ (self )
128128 # Make self the active context, pushing it on top of the context stack.
129129 self .ctx .push ()
@@ -141,8 +141,8 @@ def infer(self, image_path_batch):
141141 batch_origin_h = []
142142 batch_origin_w = []
143143 batch_input_image = np .empty (shape = [self .batch_size , 3 , self .input_h , self .input_w ])
144- for i , img_path in enumerate (image_path_batch ):
145- input_image , image_raw , origin_h , origin_w = self .preprocess_image (img_path )
144+ for i , image_raw in enumerate (raw_image_generator ):
145+ input_image , image_raw , origin_h , origin_w = self .preprocess_image (image_raw )
146146 batch_image_raw .append (image_raw )
147147 batch_origin_h .append (origin_h )
148148 batch_origin_w .append (origin_w )
@@ -166,7 +166,7 @@ def infer(self, image_path_batch):
166166 # Here we use the first row of output in that batch_size = 1
167167 output = host_outputs [0 ]
168168 # Do postprocess
169- for i , img_path in enumerate ( image_path_batch ):
169+ for i in range ( self . batch_size ):
170170 result_boxes , result_scores , result_classid = self .post_process (
171171 output [i * 6001 : (i + 1 ) * 6001 ], batch_origin_h [i ], batch_origin_w [i ]
172172 )
@@ -180,19 +180,29 @@ def infer(self, image_path_batch):
180180 categories [int (result_classid [j ])], result_scores [j ]
181181 ),
182182 )
183- parent , filename = os .path .split (img_path )
184- save_name = os .path .join ('output' , filename )
185- # Save image
186- cv2 .imwrite (save_name , batch_image_raw [i ])
187- print ('input->{}, time->{:.2f}ms, saving into output/' .format (image_path_batch , (end - start ) * 1000 ))
183+ return batch_image_raw , end - start
188184
189185 def destroy (self ):
190186 # Remove any context from the top of the context stack, deactivating it.
191187 self .ctx .pop ()
188+
189+ def get_raw_image (self , image_path_batch ):
190+ """
191+ description: Read an image from image path
192+ """
193+ for img_path in image_path_batch :
194+ yield cv2 .imread (img_path )
195+
196+ def get_raw_image_zeros (self , image_path_batch = None ):
197+ """
198+ description: Ready data for warmup
199+ """
200+ for _ in range (self .batch_size ):
201+ yield np .zeros ([self .input_h , self .input_w , 3 ], dtype = np .uint8 )
192202
193- def preprocess_image (self , input_image_path ):
203+ def preprocess_image (self , raw_bgr_image ):
194204 """
195- description: Read an image from image path, convert it to RGB,
205+ description: Convert BGR image to RGB,
196206 resize and pad it to target size, normalize to [0,1],
197207 transform to NCHW format.
198208 param:
@@ -203,7 +213,7 @@ def preprocess_image(self, input_image_path):
203213 h: original height
204214 w: original width
205215 """
206- image_raw = cv2 . imread ( input_image_path )
216+ image_raw = raw_bgr_image
207217 h , w , c = image_raw .shape
208218 image = cv2 .cvtColor (image_raw , cv2 .COLOR_BGR2RGB )
209219 # Calculate widht and height and paddings
@@ -305,22 +315,45 @@ def post_process(self, output, origin_h, origin_w):
305315 return result_boxes , result_scores , result_classid
306316
307317
308- class myThread (threading .Thread ):
309- def __init__ (self , func , args ):
318+ class inferThread (threading .Thread ):
319+ def __init__ (self , yolov5_wrapper , image_path_batch ):
310320 threading .Thread .__init__ (self )
311- self .func = func
312- self .args = args
321+ self .yolov5_wrapper = yolov5_wrapper
322+ self .image_path_batch = image_path_batch
313323
314324 def run (self ):
315- self .func (* self .args )
325+ batch_image_raw , use_time = self .yolov5_wrapper .infer (self .yolov5_wrapper .get_raw_image (self .image_path_batch ))
326+ for i , img_path in enumerate (self .image_path_batch ):
327+ parent , filename = os .path .split (img_path )
328+ save_name = os .path .join ('output' , filename )
329+ # Save image
330+ cv2 .imwrite (save_name , batch_image_raw [i ])
331+ print ('input->{}, time->{:.2f}ms, saving into output/' .format (self .image_path_batch , use_time * 1000 ))
332+
333+
334+ class warmUpThread (threading .Thread ):
335+ def __init__ (self , yolov5_wrapper ):
336+ threading .Thread .__init__ (self )
337+ self .yolov5_wrapper = yolov5_wrapper
338+
339+ def run (self ):
340+ batch_image_raw , use_time = self .yolov5_wrapper .infer (self .yolov5_wrapper .get_raw_image_zeros ())
341+ print ('warm_up->{}, time->{:.2f}ms' .format (batch_image_raw [0 ].shape , use_time * 1000 ))
342+
316343
317344
318345if __name__ == "__main__" :
319346 # load custom plugins
320347 PLUGIN_LIBRARY = "build/libmyplugins.so"
321- ctypes .CDLL (PLUGIN_LIBRARY )
322348 engine_file_path = "build/yolov5s.engine"
323349
350+ if len (sys .argv ) > 1 :
351+ engine_file_path = sys .argv [1 ]
352+ if len (sys .argv ) > 2 :
353+ PLUGIN_LIBRARY = sys .argv [2 ]
354+
355+ ctypes .CDLL (PLUGIN_LIBRARY )
356+
324357 # load coco labels
325358
326359 categories = ["person" , "bicycle" , "car" , "motorcycle" , "airplane" , "bus" , "train" , "truck" , "boat" , "traffic light" ,
@@ -338,15 +371,22 @@ def run(self):
338371 os .makedirs ('output/' )
339372 # a YoLov5TRT instance
340373 yolov5_wrapper = YoLov5TRT (engine_file_path )
341- print ('batch size is' , yolov5_wrapper .batch_size )
342- image_dir = "samples/"
343- image_path_batches = get_img_path_batches (yolov5_wrapper .batch_size , image_dir )
344-
345- for batch in image_path_batches :
346- # create a new thread to do inference
347- thread1 = myThread (yolov5_wrapper .infer , [batch ])
348- thread1 .start ()
349- thread1 .join ()
350-
351- # destroy the instance
352- yolov5_wrapper .destroy ()
374+ try :
375+ print ('batch size is' , yolov5_wrapper .batch_size )
376+
377+ image_dir = "samples/"
378+ image_path_batches = get_img_path_batches (yolov5_wrapper .batch_size , image_dir )
379+
380+ for i in range (10 ):
381+ # create a new thread to do warm_up
382+ thread1 = warmUpThread (yolov5_wrapper )
383+ thread1 .start ()
384+ thread1 .join ()
385+ for batch in image_path_batches :
386+ # create a new thread to do inference
387+ thread1 = inferThread (yolov5_wrapper , batch )
388+ thread1 .start ()
389+ thread1 .join ()
390+ finally :
391+ # destroy the instance
392+ yolov5_wrapper .destroy ()
0 commit comments