 import torch
 import torchvision
 
-
 INPUT_W = 608
 INPUT_H = 608
-CONF_THRESH = 0.5
+CONF_THRESH = 0.1
 IOU_THRESHOLD = 0.4
 
 
 def plot_one_box(x, img, color=None, label=None, line_thickness=None):
-    '''
+    """
     description: Plots one bounding box on image img,
                  this function comes from YoLov5 project.
     param:
@@ -36,9 +35,10 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
     return:
         no return
 
-    '''
-    tl = line_thickness or round(
-        0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
+    """
+    tl = (
+        line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
+    )  # line/font thickness
     color = color or [random.randint(0, 255) for _ in range(3)]
     c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
     cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
@@ -47,14 +47,23 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
         t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
         c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
         cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
-        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3,
-                    [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+        cv2.putText(
+            img,
+            label,
+            (c1[0], c1[1] - 2),
+            0,
+            tl / 3,
+            [225, 255, 255],
+            thickness=tf,
+            lineType=cv2.LINE_AA,
+        )
 
 
 class YoLov5TRT(object):
-    '''
+    """
     description: A YOLOv5 class that wraps TensorRT ops, preprocess and postprocess ops.
-    '''
+    """
+
     def __init__(self, engine_file_path):
         # Create a Context on this device,
         self.cfx = cuda.Device(0).make_context()
@@ -74,8 +83,7 @@ def __init__(self, engine_file_path):
         bindings = []
 
         for binding in engine:
-            size = trt.volume(engine.get_binding_shape(
-                binding)) * engine.max_batch_size
+            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
             dtype = trt.nptype(engine.get_binding_dtype(binding))
             # Allocate host and device buffers
             host_mem = cuda.pagelocked_empty(size, dtype)
@@ -102,7 +110,7 @@ def __init__(self, engine_file_path):
 
     def infer(self, input_image_path):
         threading.Thread.__init__(self)
-        # Make self the active context, pushing it on top of the context stack.
+        # Make self the active context, pushing it on top of the context stack.
         self.cfx.push()
         # Restore
         stream = self.stream
@@ -115,7 +123,8 @@ def infer(self, input_image_path):
         bindings = self.bindings
         # Do image preprocess
         input_image, image_raw, origin_h, origin_w = self.preprocess_image(
-            input_image_path)
+            input_image_path
+        )
         # Copy input image to host buffer
         np.copyto(host_inputs[0], input_image.ravel())
         # Transfer input data to the GPU.
@@ -132,23 +141,29 @@ def infer(self, input_image_path):
         output = host_outputs[0]
         # Do postprocess
         result_boxes, result_scores, result_classid = self.post_process(
-            output, origin_h, origin_w)
+            output, origin_h, origin_w
+        )
         # Draw rectangles and labels on the original image
         for i in range(len(result_boxes)):
             box = result_boxes[i]
-            plot_one_box(box, image_raw, label="{}:{:.2f}".format(
-                categories[int(result_classid[i])], result_scores[i]))
+            plot_one_box(
+                box,
+                image_raw,
+                label="{}:{:.2f}".format(
+                    categories[int(result_classid[i])], result_scores[i]
+                ),
+            )
         parent, filename = os.path.split(input_image_path)
-        save_name = os.path.join(parent, "output_" + filename)
-        # Save image
+        save_name = os.path.join(parent, "output_" + filename)
+        # Save image
         cv2.imwrite(save_name, image_raw)
 
     def destory(self):
         # Remove any context from the top of the context stack, deactivating it.
         self.cfx.pop()
 
     def preprocess_image(self, input_image_path):
-        '''
+        """
         description: Read an image from image path, convert it to RGB,
                      resize and pad it to target size, normalize to [0,1],
                      transform to NCHW format.
@@ -159,7 +174,7 @@ def preprocess_image(self, input_image_path):
             image_raw: the original image
             h: original height
             w: original width
-        '''
+        """
         image_raw = cv2.imread(input_image_path)
         h, w, c = image_raw.shape
         image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
@@ -169,18 +184,21 @@ def preprocess_image(self, input_image_path):
         if r_h > r_w:
             tw = INPUT_W
             th = int(r_w * h)
-            tx = 0
-            ty = int((INPUT_H - th) / 2)
+            tx1 = tx2 = 0
+            ty1 = int((INPUT_H - th) / 2)
+            ty2 = INPUT_H - th - ty1
         else:
             tw = int(r_h * w)
             th = INPUT_H
-            tx = int((INPUT_W - tw) / 2)
-            ty = 0
+            tx1 = int((INPUT_W - tw) / 2)
+            tx2 = INPUT_W - tw - tx1
+            ty1 = ty2 = 0
         # Resize the image with long side while maintaining ratio
         image = cv2.resize(image, (tw, th))
         # Pad the short side with (128,128,128)
         image = cv2.copyMakeBorder(
-            image, ty, ty, tx, tx, cv2.BORDER_CONSTANT, (128, 128, 128))
+            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, (128, 128, 128)
+        )
         image = image.astype(np.float32)
         # Normalize to [0,1]
         image /= 255.0
@@ -193,36 +211,35 @@ def preprocess_image(self, input_image_path):
         return image, image_raw, h, w
 
     def xywh2xyxy(self, origin_h, origin_w, x):
-        '''
+        """
         description: Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
         param:
             origin_h: height of original image
             origin_w: width of original image
             x: A boxes tensor, each row is a box [center_x, center_y, w, h]
         return:
             y: A boxes tensor, each row is a box [x1, y1, x2, y2]
-        '''
-        y = torch.zeros_like(x) if isinstance(
-            x, torch.Tensor) else np.zeros_like(x)
+        """
+        y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
         r_w = INPUT_W / origin_w
         r_h = INPUT_H / origin_h
         if r_h > r_w:
-            y[:, 0] = x[:, 0] - x[:, 2]/2
-            y[:, 2] = x[:, 0] + x[:, 2]/2
-            y[:, 1] = x[:, 1] - x[:, 3]/2 - (INPUT_H - r_w * origin_h) / 2
-            y[:, 3] = x[:, 1] + x[:, 3]/2 - (INPUT_H - r_w * origin_h) / 2
+            y[:, 0] = x[:, 0] - x[:, 2] / 2
+            y[:, 2] = x[:, 0] + x[:, 2] / 2
+            y[:, 1] = x[:, 1] - x[:, 3] / 2 - (INPUT_H - r_w * origin_h) / 2
+            y[:, 3] = x[:, 1] + x[:, 3] / 2 - (INPUT_H - r_w * origin_h) / 2
             y /= r_w
         else:
-            y[:, 0] = x[:, 0] - x[:, 2]/2 - (INPUT_W - r_h * origin_w) / 2
-            y[:, 2] = x[:, 0] + x[:, 2]/2 - (INPUT_W - r_h * origin_w) / 2
-            y[:, 1] = x[:, 1] - x[:, 3]/2
-            y[:, 3] = x[:, 1] + x[:, 3]/2
+            y[:, 0] = x[:, 0] - x[:, 2] / 2 - (INPUT_W - r_h * origin_w) / 2
+            y[:, 2] = x[:, 0] + x[:, 2] / 2 - (INPUT_W - r_h * origin_w) / 2
+            y[:, 1] = x[:, 1] - x[:, 3] / 2
+            y[:, 3] = x[:, 1] + x[:, 3] / 2
             y /= r_h
 
         return y
 
     def post_process(self, output, origin_h, origin_w):
-        '''
+        """
         description: postprocess the prediction
         param:
             output: A tensor like [num_boxes,cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
@@ -232,7 +249,7 @@ def post_process(self, output, origin_h, origin_w):
             result_boxes: final boxes, a boxes tensor, each row is a box [x1, y1, x2, y2]
             result_scores: final scores, a tensor, each element is the score corresponding to the box
             result_classid: final classid, a tensor, each element is the classid corresponding to the box
-        '''
+        """
         # Get the num of boxes detected
         num = int(output[0])
         # Reshape to a two-dimensional ndarray
@@ -253,8 +270,7 @@ def post_process(self, output, origin_h, origin_w):
         # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
         boxes = self.xywh2xyxy(origin_h, origin_w, boxes)
         # Do nms
-        indices = torchvision.ops.nms(
-            boxes, scores, iou_threshold=IOU_THRESHOLD).cpu()
+        indices = torchvision.ops.nms(boxes, scores, iou_threshold=IOU_THRESHOLD).cpu()
         result_boxes = boxes[indices, :].cpu()
         result_scores = scores[indices].cpu()
         result_classid = classid[indices].cpu()
@@ -271,30 +287,35 @@ def run(self):
         self.func(*self.args)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # load custom plugins
-    PLUGIN_LIBRARY = 'build/libmyplugins.so'
+    PLUGIN_LIBRARY = "build/libmyplugins.so"
     ctypes.CDLL(PLUGIN_LIBRARY)
     engine_file_path = "build/yolov5s.engine"
 
     # load coco labels
-    coco_labels = "coco_labels.txt"
-    categories = []
-    with open(coco_labels, "r") as f:
-        for line in f:
-            categories.append(line.strip())
+
+    categories = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+                  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
+                  "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+                  "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
+                  "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
+                  "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+                  "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+                  "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
+                  "hair drier", "toothbrush"]
 
     # a YoLov5TRT instance
-    yolov5_warpper = YoLov5TRT(engine_file_path)
+    yolov5_wrapper = YoLov5TRT(engine_file_path)
 
+    # from https://github.com/ultralytics/yolov5/tree/master/inference/images
     input_image_paths = ["zidane.jpg", "bus.jpg"]
 
     for input_image_path in input_image_paths:
         # create a new thread to do inference
-        thread1 = myThread(yolov5_warpper.infer, [input_image_path])
+        thread1 = myThread(yolov5_wrapper.infer, [input_image_path])
         thread1.start()
         thread1.join()
 
     # destory the instance
-    yolov5_warpper.destory()
-
+    yolov5_wrapper.destory()