Skip to content

Commit d6a0b78

Browse files
authored
fix_yolov5 (wang-xinyu#1181)
* fix_yolov5 * add fix * add Rex
1 parent aac418e commit d6a0b78

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

yolov5/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ TensorRTx inference code base for [ultralytics/yolov5](https://github.com/ultral
3030
<a href="https://github.com/triple-Mu"><img src="https://avatars.githubusercontent.com/u/92794867?s=48&v=4" width="40px;" alt=""/></a>
3131
<a href="https://github.com/xiang-wuu"><img src="https://avatars.githubusercontent.com/u/107029401?s=48&v=4" width="40px;" alt=""/></a>
3232
<a href="https://github.com/uyolo1314"><img src="https://avatars.githubusercontent.com/u/101853326?s=48&v=4" width="40px;" alt=""/></a>
33+
<a href="https://github.com/Rex-LK"><img src="https://avatars.githubusercontent.com/u/74702576?s=96&v=4" width="40px;" alt=""/></a>
3334

3435
## Different versions of yolov5
3536

yolov5/yolov5_det_trt.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
CONF_THRESH = 0.5
1818
IOU_THRESHOLD = 0.4
19-
19+
LEN_ALL_RESULT = 38001
20+
LEN_ONE_RESULT = 38
2021

2122
def get_img_path_batches(batch_size, img_dir):
2223
ret = []
@@ -166,7 +167,7 @@ def infer(self, raw_image_generator):
166167
# Do postprocess
167168
for i in range(self.batch_size):
168169
result_boxes, result_scores, result_classid = self.post_process(
169-
output[i * 6001: (i + 1) * 6001], batch_origin_h[i], batch_origin_w[i]
170+
output[i * LEN_ALL_RESULT: (i + 1) * LEN_ALL_RESULT], batch_origin_h[i], batch_origin_w[i]
170171
)
171172
# Draw rectangles and labels on the original image
172173
for j in range(len(result_boxes)):
@@ -289,7 +290,8 @@ def post_process(self, output, origin_h, origin_w):
289290
# Get the num of boxes detected
290291
num = int(output[0])
291292
# Reshape to a two dimentional ndarray
292-
pred = np.reshape(output[1:], (-1, 6))[:num, :]
293+
pred = np.reshape(output[1:], (-1, LEN_ONE_RESULT))[:num, :]
294+
pred = pred[:, :6]
293295
# Do nms
294296
boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
295297
result_boxes = boxes[:, :4] if len(boxes) else np.array([])

0 commit comments

Comments
 (0)