Skip to content

Commit 7c4476f

Browse files
authored
A pure numpy version for yolov5_trt.py (wang-xinyu#700)
* This code provides a pure numpy manner to write nms operation in yolov5_trt.py. I think it can offer simple and convenient TensorRT experience for yolov5 * delete torch related codes
1 parent 32049c6 commit 7c4476f

File tree

1 file changed

+82
-29
lines changed

1 file changed

+82
-29
lines changed

yolov5/yolov5_trt.py

Lines changed: 82 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
import pycuda.autoinit
1414
import pycuda.driver as cuda
1515
import tensorrt as trt
16-
import torch
17-
import torchvision
1816

1917
CONF_THRESH = 0.5
2018
IOU_THRESHOLD = 0.4
@@ -254,11 +252,11 @@ def xywh2xyxy(self, origin_h, origin_w, x):
254252
param:
255253
origin_h: height of original image
256254
origin_w: width of original image
257-
x: A boxes tensor, each row is a box [center_x, center_y, w, h]
255+
x: A boxes numpy, each row is a box [center_x, center_y, w, h]
258256
return:
259-
y: A boxes tensor, each row is a box [x1, y1, x2, y2]
257+
y: A boxes numpy, each row is a box [x1, y1, x2, y2]
260258
"""
261-
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
259+
y = np.zeros_like(x)
262260
r_w = self.input_w / origin_w
263261
r_h = self.input_h / origin_h
264262
if r_h > r_w:
@@ -280,40 +278,95 @@ def post_process(self, output, origin_h, origin_w):
280278
"""
281279
description: postprocess the prediction
282280
param:
283-
output: A tensor likes [num_boxes,cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
281+
output: A numpy likes [num_boxes,cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
284282
origin_h: height of original image
285283
origin_w: width of original image
286284
return:
287-
result_boxes: finally boxes, a boxes tensor, each row is a box [x1, y1, x2, y2]
288-
result_scores: finally scores, a tensor, each element is the score correspoing to box
289-
result_classid: finally classid, a tensor, each element is the classid correspoing to box
285+
result_boxes: finally boxes, a boxes numpy, each row is a box [x1, y1, x2, y2]
286+
result_scores: finally scores, a numpy, each element is the score correspoing to box
287+
result_classid: finally classid, a numpy, each element is the classid correspoing to box
290288
"""
291289
# Get the num of boxes detected
292290
num = int(output[0])
293291
# Reshape to a two dimentional ndarray
294292
pred = np.reshape(output[1:], (-1, 6))[:num, :]
295-
# to a torch Tensor
296-
pred = torch.Tensor(pred).cuda()
297-
# Get the boxes
298-
boxes = pred[:, :4]
299-
# Get the scores
300-
scores = pred[:, 4]
301-
# Get the classid
302-
classid = pred[:, 5]
303-
# Choose those boxes that score > CONF_THRESH
304-
si = scores > CONF_THRESH
305-
boxes = boxes[si, :]
306-
scores = scores[si]
307-
classid = classid[si]
308-
# Trandform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
309-
boxes = self.xywh2xyxy(origin_h, origin_w, boxes)
310293
# Do nms
311-
indices = torchvision.ops.nms(boxes, scores, iou_threshold=IOU_THRESHOLD).cpu()
312-
result_boxes = boxes[indices, :].cpu()
313-
result_scores = scores[indices].cpu()
314-
result_classid = classid[indices].cpu()
294+
boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
295+
result_boxes = boxes[:, :4] if len(boxes) else np.array([])
296+
result_scores = boxes[:, 4] if len(boxes) else np.array([])
297+
result_classid = boxes[:, 5] if len(boxes) else np.array([])
315298
return result_boxes, result_scores, result_classid
316299

300+
def bbox_iou(self, box1, box2, x1y1x2y2=True):
301+
"""
302+
description: compute the IoU of two bounding boxes
303+
param:
304+
box1: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
305+
box2: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
306+
x1y1x2y2: select the coordinate format
307+
return:
308+
iou: computed iou
309+
"""
310+
if not x1y1x2y2:
311+
# Transform from center and width to exact coordinates
312+
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
313+
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
314+
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
315+
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
316+
else:
317+
# Get the coordinates of bounding boxes
318+
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
319+
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
320+
321+
# Get the coordinates of the intersection rectangle
322+
inter_rect_x1 = np.maximum(b1_x1, b2_x1)
323+
inter_rect_y1 = np.maximum(b1_y1, b2_y1)
324+
inter_rect_x2 = np.minimum(b1_x2, b2_x2)
325+
inter_rect_y2 = np.minimum(b1_y2, b2_y2)
326+
# Intersection area
327+
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) * \
328+
np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None)
329+
# Union Area
330+
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
331+
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
332+
333+
iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
334+
335+
return iou
336+
337+
def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
338+
"""
339+
description: Removes detections with lower object confidence score than 'conf_thres' and performs
340+
Non-Maximum Suppression to further filter detections.
341+
param:
342+
prediction: detections, (x1, y1, x2, y2, conf, cls_id)
343+
origin_h: original image height
344+
origin_w: original image width
345+
conf_thres: a confidence threshold to filter detections
346+
nms_thres: a iou threshold to filter detections
347+
return:
348+
boxes: output after nms with the shape (x1, y1, x2, y2, conf, cls_id)
349+
"""
350+
# Get the boxes that score > CONF_THRESH
351+
boxes = prediction[prediction[:, 4] >= conf_thres]
352+
# Trandform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
353+
boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
354+
# Object confidence
355+
confs = boxes[:, 4]
356+
# Sort by the confs
357+
boxes = boxes[np.argsort(-confs)]
358+
# Perform non-maximum suppression
359+
keep_boxes = []
360+
while boxes.shape[0]:
361+
large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
362+
label_match = boxes[0, -1] == boxes[:, -1]
363+
# Indices of boxes with lower confidence scores, large IOUs and matching labels
364+
invalid = large_overlap & label_match
365+
keep_boxes += [boxes[0]]
366+
boxes = boxes[~invalid]
367+
boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
368+
return boxes
369+
317370

318371
class inferThread(threading.Thread):
319372
def __init__(self, yolov5_wrapper, image_path_batch):
@@ -343,7 +396,7 @@ def run(self):
343396

344397

345398
if __name__ == "__main__":
346-
# load custom plugins
399+
# load custom plugin and engine
347400
PLUGIN_LIBRARY = "build/libmyplugins.so"
348401
engine_file_path = "build/yolov5s.engine"
349402

0 commit comments

Comments
 (0)