Skip to content

Commit 9cca2bd

Browse files
committed
bug fix
1 parent c20e2f9 commit 9cca2bd

File tree

1 file changed

+74
-53
lines changed

1 file changed

+74
-53
lines changed

yolov5/yolov5_trt.py

Lines changed: 74 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616
import torch
1717
import torchvision
1818

19-
2019
INPUT_W = 608
2120
INPUT_H = 608
22-
CONF_THRESH = 0.5
21+
CONF_THRESH = 0.1
2322
IOU_THRESHOLD = 0.4
2423

2524

2625
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
27-
'''
26+
"""
2827
description: Plots one bounding box on image img,
2928
this function comes from YoLov5 project.
3029
param:
@@ -36,9 +35,10 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
3635
return:
3736
no return
3837
39-
'''
40-
tl = line_thickness or round(
41-
0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
38+
"""
39+
tl = (
40+
line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
41+
) # line/font thickness
4242
color = color or [random.randint(0, 255) for _ in range(3)]
4343
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
4444
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
@@ -47,14 +47,23 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
4747
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
4848
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
4949
cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
50-
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3,
51-
[225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
50+
cv2.putText(
51+
img,
52+
label,
53+
(c1[0], c1[1] - 2),
54+
0,
55+
tl / 3,
56+
[225, 255, 255],
57+
thickness=tf,
58+
lineType=cv2.LINE_AA,
59+
)
5260

5361

5462
class YoLov5TRT(object):
55-
'''
63+
"""
5664
description: A YOLOv5 class that wraps TensorRT ops, preprocess and postprocess ops.
57-
'''
65+
"""
66+
5867
def __init__(self, engine_file_path):
5968
# Create a Context on this device,
6069
self.cfx = cuda.Device(0).make_context()
@@ -74,8 +83,7 @@ def __init__(self, engine_file_path):
7483
bindings = []
7584

7685
for binding in engine:
77-
size = trt.volume(engine.get_binding_shape(
78-
binding)) * engine.max_batch_size
86+
size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
7987
dtype = trt.nptype(engine.get_binding_dtype(binding))
8088
# Allocate host and device buffers
8189
host_mem = cuda.pagelocked_empty(size, dtype)
@@ -102,7 +110,7 @@ def __init__(self, engine_file_path):
102110

103111
def infer(self, input_image_path):
104112
threading.Thread.__init__(self)
105-
# Make self the active context, pushing it on top of the context stack.
113+
# Make self the active context, pushing it on top of the context stack.
106114
self.cfx.push()
107115
# Restore
108116
stream = self.stream
@@ -115,7 +123,8 @@ def infer(self, input_image_path):
115123
bindings = self.bindings
116124
# Do image preprocess
117125
input_image, image_raw, origin_h, origin_w = self.preprocess_image(
118-
input_image_path)
126+
input_image_path
127+
)
119128
# Copy input image to host buffer
120129
np.copyto(host_inputs[0], input_image.ravel())
121130
# Transfer input data to the GPU.
@@ -132,23 +141,29 @@ def infer(self, input_image_path):
132141
output = host_outputs[0]
133142
# Do postprocess
134143
result_boxes, result_scores, result_classid = self.post_process(
135-
output, origin_h, origin_w)
144+
output, origin_h, origin_w
145+
)
136146
# Draw rectangles and labels on the original image
137147
for i in range(len(result_boxes)):
138148
box = result_boxes[i]
139-
plot_one_box(box, image_raw, label="{}:{:.2f}".format(
140-
categories[int(result_classid[i])], result_scores[i]))
149+
plot_one_box(
150+
box,
151+
image_raw,
152+
label="{}:{:.2f}".format(
153+
categories[int(result_classid[i])], result_scores[i]
154+
),
155+
)
141156
parent, filename = os.path.split(input_image_path)
142-
save_name = os.path.join(parent, "output_"+filename)
143-
# Save image
157+
save_name = os.path.join(parent, "output_" + filename)
158+
#  Save image
144159
cv2.imwrite(save_name, image_raw)
145160

146161
def destory(self):
147162
# Remove any context from the top of the context stack, deactivating it.
148163
self.cfx.pop()
149164

150165
def preprocess_image(self, input_image_path):
151-
'''
166+
"""
152167
description: Read an image from image path, convert it to RGB,
153168
resize and pad it to target size, normalize to [0,1],
154169
transform to NCHW format.
@@ -159,7 +174,7 @@ def preprocess_image(self, input_image_path):
159174
image_raw: the original image
160175
h: original height
161176
w: original width
162-
'''
177+
"""
163178
image_raw = cv2.imread(input_image_path)
164179
h, w, c = image_raw.shape
165180
image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
@@ -169,18 +184,21 @@ def preprocess_image(self, input_image_path):
169184
if r_h > r_w:
170185
tw = INPUT_W
171186
th = int(r_w * h)
172-
tx = 0
173-
ty = int((INPUT_H - th) / 2)
187+
tx1 = tx2 = 0
188+
ty1 = int((INPUT_H - th) / 2)
189+
ty2 = INPUT_H - th - ty1
174190
else:
175191
tw = int(r_h * w)
176192
th = INPUT_H
177-
tx = int((INPUT_W - tw) / 2)
178-
ty = 0
193+
tx1 = int((INPUT_W - tw) / 2)
194+
tx2 = INPUT_W - tw - tx1
195+
ty1 = ty2 = 0
179196
# Resize the image with long side while maintaining ratio
180197
image = cv2.resize(image, (tw, th))
181198
# Pad the short side with (128,128,128)
182199
image = cv2.copyMakeBorder(
183-
image, ty, ty, tx, tx, cv2.BORDER_CONSTANT, (128, 128, 128))
200+
image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, (128, 128, 128)
201+
)
184202
image = image.astype(np.float32)
185203
# Normalize to [0,1]
186204
image /= 255.0
@@ -193,36 +211,35 @@ def preprocess_image(self, input_image_path):
193211
return image, image_raw, h, w
194212

195213
def xywh2xyxy(self, origin_h, origin_w, x):
196-
'''
214+
"""
197215
description: Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
198216
param:
199217
origin_h: height of original image
200218
origin_w: width of original image
201219
x: A boxes tensor, each row is a box [center_x, center_y, w, h]
202220
return:
203221
y: A boxes tensor, each row is a box [x1, y1, x2, y2]
204-
'''
205-
y = torch.zeros_like(x) if isinstance(
206-
x, torch.Tensor) else np.zeros_like(x)
222+
"""
223+
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
207224
r_w = INPUT_W / origin_w
208225
r_h = INPUT_H / origin_h
209226
if r_h > r_w:
210-
y[:, 0] = x[:, 0] - x[:, 2]/2
211-
y[:, 2] = x[:, 0] + x[:, 2]/2
212-
y[:, 1] = x[:, 1] - x[:, 3]/2 - (INPUT_H - r_w * origin_h) / 2
213-
y[:, 3] = x[:, 1] + x[:, 3]/2 - (INPUT_H - r_w * origin_h) / 2
227+
y[:, 0] = x[:, 0] - x[:, 2] / 2
228+
y[:, 2] = x[:, 0] + x[:, 2] / 2
229+
y[:, 1] = x[:, 1] - x[:, 3] / 2 - (INPUT_H - r_w * origin_h) / 2
230+
y[:, 3] = x[:, 1] + x[:, 3] / 2 - (INPUT_H - r_w * origin_h) / 2
214231
y /= r_w
215232
else:
216-
y[:, 0] = x[:, 0] - x[:, 2]/2 - (INPUT_W - r_h * origin_w) / 2
217-
y[:, 2] = x[:, 0] + x[:, 2]/2 - (INPUT_W - r_h * origin_w) / 2
218-
y[:, 1] = x[:, 1] - x[:, 3]/2
219-
y[:, 3] = x[:, 1] + x[:, 3]/2
233+
y[:, 0] = x[:, 0] - x[:, 2] / 2 - (INPUT_W - r_h * origin_w) / 2
234+
y[:, 2] = x[:, 0] + x[:, 2] / 2 - (INPUT_W - r_h * origin_w) / 2
235+
y[:, 1] = x[:, 1] - x[:, 3] / 2
236+
y[:, 3] = x[:, 1] + x[:, 3] / 2
220237
y /= r_h
221238

222239
return y
223240

224241
def post_process(self, output, origin_h, origin_w):
225-
'''
242+
"""
226243
description: postprocess the prediction
227244
param:
228245
output: A tensor like [num_boxes,cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
@@ -232,7 +249,7 @@ def post_process(self, output, origin_h, origin_w):
232249
result_boxes: final boxes, a boxes tensor, each row is a box [x1, y1, x2, y2]
233250
result_scores: final scores, a tensor, each element is the score corresponding to a box
234251
result_classid: final class ids, a tensor, each element is the class id corresponding to a box
235-
'''
252+
"""
236253
# Get the num of boxes detected
237254
num = int(output[0])
238255
# Reshape to a two-dimensional ndarray
@@ -253,8 +270,7 @@ def post_process(self, output, origin_h, origin_w):
253270
# Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
254271
boxes = self.xywh2xyxy(origin_h, origin_w, boxes)
255272
# Do nms
256-
indices = torchvision.ops.nms(
257-
boxes, scores, iou_threshold=IOU_THRESHOLD).cpu()
273+
indices = torchvision.ops.nms(boxes, scores, iou_threshold=IOU_THRESHOLD).cpu()
258274
result_boxes = boxes[indices, :].cpu()
259275
result_scores = scores[indices].cpu()
260276
result_classid = classid[indices].cpu()
@@ -271,30 +287,35 @@ def run(self):
271287
self.func(*self.args)
272288

273289

274-
if __name__ == '__main__':
290+
if __name__ == "__main__":
275291
# load custom plugins
276-
PLUGIN_LIBRARY = 'build/libmyplugins.so'
292+
PLUGIN_LIBRARY = "build/libmyplugins.so"
277293
ctypes.CDLL(PLUGIN_LIBRARY)
278294
engine_file_path = "build/yolov5s.engine"
279295

280296
# load coco labels
281-
coco_labels = "coco_labels.txt"
282-
categories = []
283-
with open(coco_labels, "r") as f:
284-
for line in f:
285-
categories.append(line.strip())
297+
298+
categories = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
299+
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
300+
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
301+
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
302+
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
303+
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
304+
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
305+
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
306+
"hair drier", "toothbrush"]
286307

287308
# a YoLov5TRT instance
288-
yolov5_warpper = YoLov5TRT(engine_file_path)
309+
yolov5_wrapper = YoLov5TRT(engine_file_path)
289310

311+
# from https://github.com/ultralytics/yolov5/tree/master/inference/images
290312
input_image_paths = ["zidane.jpg", "bus.jpg"]
291313

292314
for input_image_path in input_image_paths:
293315
# create a new thread to do inference
294-
thread1 = myThread(yolov5_warpper.infer, [input_image_path])
316+
thread1 = myThread(yolov5_wrapper.infer, [input_image_path])
295317
thread1.start()
296318
thread1.join()
297319

298320
# destroy the instance
299-
yolov5_warpper.destory()
300-
321+
yolov5_wrapper.destory()

0 commit comments

Comments
 (0)