Skip to content

Commit e56ec6e

Browse files
committed
fix yololayer
1 parent c668800 commit e56ec6e

File tree

4 files changed

+38
-25
lines changed

4 files changed

+38
-25
lines changed

README.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ Following models are implemented, each one also has a readme inside.
3434
|[mnasnet](./mnasnet)| MNASNet with depth multiplier of 0.5 from the paper |
3535
|[mobilenet](./mobilenetv2)| MobileNet V2, V3-small, V3-large. |
3636
|[resnet](./resnet)| resnet-18, resnet-50 and resnext50-32x4d are implemented |
37-
|[senet](./senet)| se_resnet50 |
37+
|[senet](./senet)| se-resnet50 |
3838
|[shufflenet](./shufflenetv2)| ShuffleNetV2 with 0.5x output channels |
3939
|[squeezenet](./squeezenet)| SqueezeNet 1.1 model |
4040
|[vgg](./vgg)| VGG 11-layer model |
4141
|[yolov3](./yolov3)| darknet-53, weights from yolov3 authors |
42+
|[yolov3-spp](./yolov3-spp)| darknet-53, weights from [ultralytics/yolov3](https://github.com/ultralytics/yolov3) |
4243

4344
## Tricky Operations
4445

@@ -54,7 +55,18 @@ Some tricky operations encountered in these models, already solved, but might ha
5455
|channel shuffle| use two shuffle layers to implement `channel_shuffle`, see shufflenet. |
5556
|adaptive pool| use fixed input dimension, and use regular average pooling, see shufflenet. |
5657
|leaky relu| I wrote a leaky relu plugin, but PRelu in `NvInferPlugin.h` can be used, see yolov3. |
57-
|yolo layer| yolo layer is implemented as a plugin, see yolov3. |
58+
|yolo layer v1| yolo layer is implemented as a plugin, see yolov3. |
59+
|yolo layer v2| three yolo layers implemented in one plugin, see yolov3-spp. |
5860
|upsample| replaced by a deconvolution layer, see yolov3. |
5961
|hsigmoid| hard sigmoid is implemented as a plugin, hsigmoid and hswish are used in mobilenetv3 |
6062

63+
## Speed Benchmark
64+
65+
| Models | Device | BatchSize | Mode | Input Shape(HxW) | FPS |
66+
|-|-|:-:|:-:|:-:|:-:|
67+
| yolov3(darknet53) | Xavier | 1 | FP16 | 320x320 | 55 |
68+
| yolov3-spp(darknet53) | GTX1080 | 1 | FP32 | 256x416 | 94 |
69+
70+
Help wanted, if you got speed results, please add an issue or PR.
71+
72+
Thanks @Kmarconi for yolov3(darknet53) speed test.

yolov3-spp/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# yolov3-spp
22

3-
The Pytorch implementation is [ultralytics/yolov3](https://github.com/ultralytics/yolov3)
3+
The Pytorch implementation is [ultralytics/yolov3](https://github.com/ultralytics/yolov3). It provides two trained weights of yolov3-spp, `yolov3-spp.pt` and `yolov3-spp-ultralytics.pt`(originally named `ultralytics68.pt`).
44

55
Following tricks are used in this yolov3-spp:
66

@@ -10,14 +10,14 @@ Following tricks are used in this yolov3-spp:
1010
## Excute:
1111

1212
```
13-
1. generate yolov3-spp_ultralytics68.wts from pytorch implementation with yolov3-spp.cfg and ultralytics68.pt
13+
1. generate yolov3-spp_ultralytics68.wts from pytorch implementation with yolov3-spp.cfg and yolov3-spp-ultralytics.pt
1414
1515
git clone https://github.com/wang-xinyu/tensorrtx.git
1616
git clone https://github.com/ultralytics/yolov3.git
17-
// download its weights 'ultralytics68.pt'
17+
// download its weights 'yolov3-spp-ultralytics.pt'
1818
cd yolov3
1919
cp ../tensorrtx/yolov3-spp/gen_wts.py .
20-
python gen_wts.py ultralytics68.pt
20+
python gen_wts.py yolov3-spp-ultralytics.pt
2121
// a file 'yolov3-spp_ultralytics68.wts' will be generated.
2222
// the master branch of yolov3 should work, if not, you can checkout 4ac60018f6e6c1e24b496485f126a660d9c793d8
2323

yolov3-spp/yololayer.cu

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -195,33 +195,34 @@ namespace nvinfer1
195195
//int out_row = input_col;
196196

197197
for (int k = 0; k < 3; ++k) {
198-
float *res_count = output;
199-
if(*res_count > 1000) break;
200-
int count = (int)atomicAdd(res_count, 1);
201-
char* data = (char * )res_count + sizeof(float) + count*sizeof(Detection);
202-
Detection* det = (Detection*)(data);
203-
204198
int class_id = 0;
205-
float max_prob = 0.0;
199+
float max_cls_prob = 0.0;
206200
for (int i = 5; i < info_len_i; ++i) {
207201
float p = Logist(input[input_col + k * info_len_i * total_grid + i * total_grid]);
208-
if (p > max_prob) {
209-
max_prob = p;
202+
if (p > max_cls_prob) {
203+
max_cls_prob = p;
210204
class_id = i - 5;
211205
}
212206
}
207+
float box_prob = Logist(input[input_col + k * info_len_i * total_grid + 4 * total_grid]);
208+
if (max_cls_prob < 0.1 || box_prob < 0.1) continue;
209+
210+
float *res_count = output;
211+
int count = (int)atomicAdd(res_count, 1);
212+
char* data = (char * )res_count + sizeof(float) + count*sizeof(Detection);
213+
Detection* det = (Detection*)(data);
213214

214215
int row = idx / yoloWidth;
215216
int col = idx % yoloWidth;
216217

217-
//Location
218+
//Location
218219
det->bbox[0] = (col + Logist(input[input_col + k * info_len_i * total_grid + 0 * total_grid])) * INPUT_W / yoloWidth;
219220
det->bbox[1] = (row + Logist(input[input_col + k * info_len_i * total_grid + 1 * total_grid])) * INPUT_H / yoloHeight;
220221
det->bbox[2] = exp(input[input_col + k * info_len_i * total_grid + 2 * total_grid]) * anchors[2*k];
221222
det->bbox[3] = exp(input[input_col + k * info_len_i * total_grid + 3 * total_grid]) * anchors[2*k + 1];
222-
det->det_confidence = Logist(input[input_col + k * info_len_i * total_grid + 4 * total_grid]);
223-
det->class_id = class_id;
224-
det->class_confidence = max_prob;
223+
det->det_confidence = box_prob;
224+
det->class_id = class_id;
225+
det->class_confidence = max_cls_prob;
225226
}
226227
}
227228

@@ -247,9 +248,9 @@ namespace nvinfer1
247248
numElem = yolo.width*yolo.height*batchSize;
248249
if (numElem < 256)
249250
mThreadCount = numElem;
250-
CUDA_CHECK(cudaMemcpy(devAnchor, yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
251+
CUDA_CHECK(cudaMemcpy(devAnchor, yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
251252
CalDetection<<< (yolo.width*yolo.height*batchSize + mThreadCount - 1) / mThreadCount, mThreadCount>>>
252-
(inputs[i],output, numElem, yolo.width, yolo.height, (float *)devAnchor, mClassCount ,outputElem);
253+
(inputs[i],output, numElem, yolo.width, yolo.height, (float *)devAnchor, mClassCount ,outputElem);
253254
}
254255

255256
CUDA_CHECK(cudaFree(devAnchor));

yolov3-spp/yolov3-spp.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using namespace Yolo;
2222
// stuff we know about the network and the input/output blobs
2323
static const int INPUT_H = 256;
2424
static const int INPUT_W = 416;
25-
static const int OUTPUT_SIZE = 1000 * 7 + 1;
25+
static const int OUTPUT_SIZE = 1000 * 7 + 1; // we assume the yololayer outputs no more than 1000 boxes that conf >= 0.1
2626
const char* INPUT_BLOB_NAME = "data";
2727
const char* OUTPUT_BLOB_NAME = "prob";
2828
static Logger gLogger;
@@ -96,7 +96,7 @@ bool cmp(Detection& a, Detection& b) {
9696

9797
void nms(std::vector<Detection>& res, float *output, float nms_thresh = 0.4) {
9898
std::map<float, std::vector<Detection>> m;
99-
for (int i = 0; i < output[0]; i++) {
99+
for (int i = 0; i < output[0] && i < 1000; i++) {
100100
if (output[1 + 7 * i + 4] <= 0.5) continue;
101101
Detection det;
102102
memcpy(&det, &output[1 + 7 * i], 7 * sizeof(float));
@@ -537,6 +537,8 @@ int main(int argc, char** argv) {
537537
doInference(*context, data, prob, 1);
538538
std::vector<Detection> res;
539539
nms(res, prob);
540+
auto end = std::chrono::system_clock::now();
541+
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
540542
for (int i=0; i<20; i++) {
541543
std::cout << prob[i] << ",";
542544
}
@@ -551,8 +553,6 @@ int main(int argc, char** argv) {
551553
cv::rectangle(img, r, cv::Scalar(0x27, 0xC1, 0x36), 2);
552554
cv::putText(img, std::to_string((int)res[j].class_id), cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 1.2, cv::Scalar(0xFF, 0xFF, 0xFF), 2);
553555
}
554-
auto end = std::chrono::system_clock::now();
555-
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
556556
cv::imwrite("_" + f, img);
557557
}
558558

0 commit comments

Comments
 (0)