Skip to content

Commit 3ae17d4

Browse files
committed
retinaface update
1 parent 35d59c9 commit 3ae17d4

File tree

5 files changed

+122
-125
lines changed

5 files changed

+122
-125
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Following models are implemented, each one also has a readme inside.
4040
|[vgg](./vgg)| VGG 11-layer model |
4141
|[yolov3](./yolov3)| darknet-53, weights from yolov3 authors |
4242
|[yolov3-spp](./yolov3-spp)| darknet-53, weights from [ultralytics/yolov3](https://github.com/ultralytics/yolov3) |
43+
|[retinaface](./retinaface)| resnet-50, weights from [biubug6/Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) |
4344

4445
## Tricky Operations
4546

@@ -59,13 +60,15 @@ Some tricky operations encountered in these models, already solved, but might ha
5960
|yolo layer v2| three yolo layers implemented in one plugin, see yolov3-spp. |
6061
|upsample| replaced by a deconvolution layer, see yolov3. |
6162
|hsigmoid| hard sigmoid is implemented as a plugin, hsigmoid and hswish are used in mobilenetv3 |
63+
|retinaface output decode| implement a plugin to decode bbox, confidence and landmarks, see retinaface. |
6264

6365
## Speed Benchmark
6466

6567
| Models | Device | BatchSize | Mode | Input Shape(HxW) | FPS |
6668
|-|-|:-:|:-:|:-:|:-:|
67-
| yolov3(darknet53) | Xavier | 1 | FP16 | 320x320 | 55 |
68-
| yolov3-spp(darknet53) | GTX1080 | 1 | FP32 | 256x416 | 94 |
69+
| YOLOv3(darknet53) | Xavier | 1 | FP16 | 320x320 | 55 |
70+
| YOLOv3-spp(darknet53) | GTX1080 | 1 | FP32 | 256x416 | 94 |
71+
| RetinaFace(resnet50) | TX2 | 1 | FP16 | 384x640 | 15 |
6972

7073
Help wanted, if you got speed results, please add an issue or PR.
7174

retinaface/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
# RetinaFace
22

3-
still working in progress
3+
## Notice
4+
5+
- Only tested on TensorRT4.1.3
6+
7+

retinaface/decode.cu

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,22 @@ namespace nvinfer1
3434
{
3535
//output the result to channel
3636
int totalCount = 0;
37-
totalCount += input_h_ / 8 * input_w_ / 8 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
38-
totalCount += input_h_ / 16 * input_w_ / 16 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
39-
totalCount += input_h_ / 32 * input_w_ / 32 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
37+
totalCount += decodeplugin::INPUT_H / 8 * decodeplugin::INPUT_W / 8 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
38+
totalCount += decodeplugin::INPUT_H / 16 * decodeplugin::INPUT_W / 16 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
39+
totalCount += decodeplugin::INPUT_H / 32 * decodeplugin::INPUT_W / 32 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
4040

4141
return Dims3(totalCount + 1, 1, 1);
4242
}
4343

4444
__device__ float Logist(float data){ return 1./(1. + exp(-data)); };
4545

46-
__global__ void CalDetection(const float *input, float *output, int num_elem, int input_h, int input_w, int step, int anchor) {
46+
__global__ void CalDetection(const float *input, float *output, int num_elem, int step, int anchor) {
4747

4848
int idx = threadIdx.x + blockDim.x * blockIdx.x;
4949
if (idx >= num_elem) return;
5050

51-
int h = input_h / step;
52-
int w = input_w / step;
51+
int h = decodeplugin::INPUT_H / step;
52+
int w = decodeplugin::INPUT_W / step;
5353
int y = idx / w;
5454
int x = idx % w;
5555
const float *bbox_reg = &input[0];
@@ -60,7 +60,7 @@ namespace nvinfer1
6060
float conf1 = cls_reg[idx + k * num_elem * 2];
6161
float conf2 = cls_reg[idx + k * num_elem * 2 + num_elem];
6262
conf2 = exp(conf2) / (exp(conf1) + exp(conf2));
63-
if (conf2 <= 0.002) continue;
63+
if (conf2 <= 0.02) continue;
6464

6565
float *res_count = output;
6666
int count = (int)atomicAdd(res_count, 1);
@@ -70,10 +70,8 @@ namespace nvinfer1
7070
float prior[4];
7171
prior[0] = ((float)x + 0.5) / w;
7272
prior[1] = ((float)y + 0.5) / h;
73-
prior[2] = (float)anchor / input_w;
74-
prior[3] = (float)anchor / input_h;
75-
printf("prior0, %f\n", prior[0]);
76-
printf("bbox0, %f\n", bbox_reg[idx + k * num_elem * 4]);
73+
prior[2] = (float)anchor * (k + 1) / decodeplugin::INPUT_W;
74+
prior[3] = (float)anchor * (k + 1) / decodeplugin::INPUT_H;
7775

7876
//Location
7977
det->bbox[0] = prior[0] + bbox_reg[idx + k * num_elem * 4] * 0.1 * prior[2];
@@ -84,12 +82,17 @@ namespace nvinfer1
8482
det->bbox[1] -= det->bbox[3] / 2;
8583
det->bbox[2] += det->bbox[0];
8684
det->bbox[3] += det->bbox[1];
87-
det->bbox[0] *= input_w;
88-
det->bbox[1] *= input_h;
89-
det->bbox[2] *= input_w;
90-
det->bbox[3] *= input_h;
85+
det->bbox[0] *= decodeplugin::INPUT_W;
86+
det->bbox[1] *= decodeplugin::INPUT_H;
87+
det->bbox[2] *= decodeplugin::INPUT_W;
88+
det->bbox[3] *= decodeplugin::INPUT_H;
9189
det->class_confidence = conf2;
92-
anchor *= 2;
90+
for (int i = 0; i < 10; i += 2) {
91+
det->landmark[i] = prior[0] + lmk_reg[idx + k * num_elem * 10 + num_elem * i] * 0.1 * prior[2];
92+
det->landmark[i+1] = prior[1] + lmk_reg[idx + k * num_elem * 10 + num_elem * (i + 1)] * 0.1 * prior[3];
93+
det->landmark[i] *= decodeplugin::INPUT_W;
94+
det->landmark[i+1] *= decodeplugin::INPUT_H;
95+
}
9396
}
9497
}
9598

@@ -101,17 +104,15 @@ namespace nvinfer1
101104
int thread_count;
102105
for (unsigned int i = 0; i < 3; ++i)
103106
{
104-
num_elem = input_h_ / base_step * input_w_ / base_step;
107+
num_elem = decodeplugin::INPUT_H / base_step * decodeplugin::INPUT_W / base_step;
105108
thread_count = (num_elem < thread_count_) ? num_elem : thread_count_;
106109
CalDetection<<< (num_elem + thread_count - 1) / thread_count, thread_count>>>
107-
(inputs[i], output, num_elem, input_h_, input_w_, base_step, base_anchor);
110+
(inputs[i], output, num_elem, base_step, base_anchor);
108111
base_step *= 2;
109112
base_anchor *= 4;
110113
}
111-
112114
}
113115

114-
115116
int DecodePlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
116117
{
117118
//assert(batchSize == 1);

retinaface/decode.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
namespace decodeplugin
77
{
88
struct alignas(float) Detection{
9-
//x y w h
10-
float bbox[4];
9+
float bbox[4]; //x1 y1 x2 y2
1110
float class_confidence;
1211
float landmark[10];
1312
};
13+
static const int INPUT_H = 384;
14+
static const int INPUT_W = 640;
1415
}
1516

1617

@@ -52,8 +53,6 @@ namespace nvinfer1
5253
void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
5354

5455
private:
55-
const int input_h_ = 384;
56-
const int input_w_ = 640;
5756
int thread_count_ = 256;
5857
};
5958
};

0 commit comments

Comments
 (0)