1616import torch
1717import torchvision
1818
19- INPUT_W = 608
20- INPUT_H = 608
21- CONF_THRESH = 0.1
19+ INPUT_W = 640
20+ INPUT_H = 640
21+ CONF_THRESH = 0.5
2222IOU_THRESHOLD = 0.4
2323
2424
@@ -66,7 +66,7 @@ class YoLov5TRT(object):
6666
6767 def __init__ (self , engine_file_path ):
6868 # Create a Context on this device,
69- self .cfx = cuda .Device (0 ).make_context ()
69+ self .ctx = cuda .Device (0 ).make_context ()
7070 stream = cuda .Stream ()
7171 TRT_LOGGER = trt .Logger (trt .Logger .INFO )
7272 runtime = trt .Runtime (TRT_LOGGER )
@@ -111,7 +111,7 @@ def __init__(self, engine_file_path):
111111 def infer (self , input_image_path ):
112112 threading .Thread .__init__ (self )
113113 # Make self the active context, pushing it on top of the context stack.
114- self .cfx .push ()
114+ self .ctx .push ()
115115 # Restore
116116 stream = self .stream
117117 context = self .context
@@ -127,6 +127,7 @@ def infer(self, input_image_path):
127127 )
128128 # Copy input image to host buffer
129129 np .copyto (host_inputs [0 ], input_image .ravel ())
130+ start = time .time ()
130131 # Transfer input data to the GPU.
131132 cuda .memcpy_htod_async (cuda_inputs [0 ], host_inputs [0 ], stream )
132133 # Run inference.
@@ -135,8 +136,9 @@ def infer(self, input_image_path):
135136 cuda .memcpy_dtoh_async (host_outputs [0 ], cuda_outputs [0 ], stream )
136137 # Synchronize the stream
137138 stream .synchronize ()
139+ end = time .time ()
138140 # Remove any context from the top of the context stack, deactivating it.
139- self .cfx .pop ()
141+ self .ctx .pop ()
140142 # Here we use the first row of output in that batch_size = 1
141143 output = host_outputs [0 ]
142144 # Do postprocess
@@ -155,12 +157,13 @@ def infer(self, input_image_path):
155157 )
156158 parent , filename = os .path .split (input_image_path )
157159 save_name = os .path .join (parent , "output_" + filename )
158- # Save image
160+ # Save image
159161 cv2 .imwrite (save_name , image_raw )
162+ print ('{:.2f}ms, saving {}' .format ((end - start ) * 1000 , save_name ))
160163
161164 def destroy (self ):
162165 # Remove any context from the top of the context stack, deactivating it.
163- self .cfx .pop ()
166+ self .ctx .pop ()
164167
165168 def preprocess_image (self , input_image_path ):
166169 """
@@ -308,8 +311,7 @@ def run(self):
308311 # a YoLov5TRT instance
309312 yolov5_wrapper = YoLov5TRT (engine_file_path )
310313
311- # from https://github.com/ultralytics/yolov5/tree/master/inference/images
312- input_image_paths = ["zidane.jpg" , "bus.jpg" ]
314+ input_image_paths = ["samples/zidane.jpg" , "samples/bus.jpg" ]
313315
314316 for input_image_path in input_image_paths :
315317 # create a new thread to do inference
0 commit comments