Skip to content

Commit 35d59c9

Browse files
committed
retinaface ssh, decode plugin
1 parent a472fbb commit 35d59c9

File tree

6 files changed

+324
-55
lines changed

6 files changed

+324
-55
lines changed

retinaface/CMakeLists.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@ link_directories(/usr/local/cuda-9.0/targets/aarch64-linux/lib)
1818

1919
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")
2020

21-
#cuda_add_library(leaky ${PROJECT_SOURCE_DIR}/leaky.cu)
22-
#cuda_add_library(yololayer ${PROJECT_SOURCE_DIR}/yololayer.cu)
21+
cuda_add_library(decodeplugin SHARED ${PROJECT_SOURCE_DIR}/decode.cu)
2322

2423
find_package(OpenCV)
2524
# Add the header search paths reported by find_package(OpenCV).
# The variable has to be dereferenced; the bare name would add a literal
# directory called "OpenCV_INCLUDE_DIRS" instead.
include_directories(${OpenCV_INCLUDE_DIRS})
2625

27-
add_executable(retina_50 ${PROJECT_SOURCE_DIR}/retina_r50.cpp)
26+
add_executable(retina_50 ${PROJECT_SOURCE_DIR}/plugin_factory.cpp ${PROJECT_SOURCE_DIR}/retina_r50.cpp)
2827
target_link_libraries(retina_50 nvinfer nvinfer_plugin)
2928
target_link_libraries(retina_50 cudart)
30-
#target_link_libraries(retina yololayer)
29+
target_link_libraries(retina_50 decodeplugin)
3130
target_link_libraries(retina_50 ${OpenCV_LIBRARIES})
3231

3332
add_definitions(-O2 -pthread)

retinaface/decode.cu

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#include "decode.h"
2+
#include "stdio.h"
3+
4+
namespace nvinfer1
5+
{
6+
// Build-time constructor.
// cudaThread: stored in thread_count_, which forwardGpu() uses as the upper
//             bound on threads per block when launching the decode kernel.
DecodePlugin::DecodePlugin(const int cudaThread)
    : thread_count_(cudaThread)
{
}
9+
10+
// The plugin owns no resources, so there is nothing to release.
DecodePlugin::~DecodePlugin()
{
}
13+
14+
// Runtime constructor: re-creates the plugin from a serialized byte stream.
// Nothing is read from `data`; serialize() writes no bytes, so every member
// keeps its in-class default value.
DecodePlugin::DecodePlugin(const void* data, size_t length)
{
}
18+
19+
// Writes the plugin state into `buffer`.
// The plugin currently persists no state, so the buffer is left untouched
// (getSerializationSize() reports 0 bytes accordingly).
void DecodePlugin::serialize(void* buffer)
{
}
22+
23+
// Number of bytes serialize() needs; zero because no state is persisted.
size_t DecodePlugin::getSerializationSize()
{
    const size_t serializedBytes = 0;
    return serializedBytes;
}
27+
28+
// No set-up work is required before execution; always returns 0.
int DecodePlugin::initialize()
{
    const int status = 0;
    return status;
}
32+
33+
// Reports the shape of the single output tensor.
//
// The output is a flat float buffer: one leading float (used as the detection
// counter by the decode kernel) followed by room for one Detection record per
// anchor, with two anchors per cell on the stride-8, stride-16 and stride-32
// feature maps. The arguments are not consulted; the size is derived from the
// fixed input_h_ / input_w_ members.
Dims DecodePlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
    int totalCount = 0;
    for (int stride = 8; stride <= 32; stride *= 2)
    {
        // cells * 2 anchors * floats per Detection record
        totalCount += input_h_ / stride * input_w_ / stride * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
    }

    // +1 for the leading counter slot.
    return Dims3(totalCount + 1, 1, 1);
}
43+
44+
// Logistic sigmoid, 1 / (1 + e^-x), evaluated entirely in single precision
// (float literals and expf) so no double-precision math is emitted on device.
__device__ float Logist(float data)
{
    return 1.0f / (1.0f + expf(-data));
}
45+
46+
// Decodes one RetinaFace feature-map level into Detection records.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per feature-map cell
// (idx in [0, num_elem)); surplus threads return immediately.
//
// input    : planar head output for this level, each plane holding num_elem
//            floats. As indexed here: 2*4 box-regression planes first, then
//            2*2 class-score planes. Any planes after that (landmarks) are
//            not read by this kernel.
// output   : output[0] is a float detection counter that is incremented
//            atomically; Detection records are packed directly behind it.
//            The caller must zero output[0] before the first launch.
// num_elem : number of cells on this level, (input_h/step) * (input_w/step).
// input_h, input_w : network input size in pixels.
// step     : stride of this level relative to the network input.
// anchor   : side length in pixels of the first anchor; the second anchor of
//            the level is twice that size.
//
// Only bbox[] and class_confidence are written; landmark[] is left untouched.
__global__ void CalDetection(const float *input, float *output, int num_elem, int input_h, int input_w, int step, int anchor) {

    const int idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx >= num_elem) return;

    const int h = input_h / step;
    const int w = input_w / step;
    const int y = idx / w;
    const int x = idx % w;
    const float *bbox_reg = &input[0];
    const float *cls_reg = &input[2 * 4 * num_elem];

    // Prior centre is shared by both anchors of the cell (normalised to [0, 1]).
    const float prior_cx = ((float)x + 0.5f) / w;
    const float prior_cy = ((float)y + 0.5f) / h;

    for (int k = 0; k < 2; ++k) {
        const float conf1 = cls_reg[idx + k * num_elem * 2];
        const float conf2 = cls_reg[idx + k * num_elem * 2 + num_elem];
        // Two-class softmax probability of the second class, written as
        // sigmoid(conf2 - conf1) so large logits cannot overflow to inf/inf.
        const float score = 1.0f / (1.0f + expf(conf1 - conf2));
        if (score <= 0.002f) continue;

        // Anchor size depends only on k. Deriving it here (instead of doubling
        // a running value at the end of the loop body) keeps it correct when
        // the first anchor was rejected by the score test above.
        const float anchor_size = (float)(anchor << k);
        const float prior_w = anchor_size / input_w;
        const float prior_h = anchor_size / input_h;

        // Reserve the next free record slot behind the counter.
        const int count = (int)atomicAdd(output, 1.0f);
        decodeplugin::Detection *det =
            reinterpret_cast<decodeplugin::Detection *>(output + 1) + count;

        // Box regression planes of anchor k: dx, dy, dw, dh.
        const float *reg = bbox_reg + idx + k * num_elem * 4;
        const float cx = prior_cx + reg[0] * 0.1f * prior_w;
        const float cy = prior_cy + reg[num_elem] * 0.1f * prior_h;
        const float bw = prior_w * expf(reg[num_elem * 2] * 0.2f);
        const float bh = prior_h * expf(reg[num_elem * 3] * 0.2f);

        // Centre/size -> corner form, then scale to input pixels.
        const float x1 = cx - bw / 2;
        const float y1 = cy - bh / 2;
        const float x2 = bw + x1;
        const float y2 = bh + y1;

        det->bbox[0] = x1 * input_w;
        det->bbox[1] = y1 * input_h;
        det->bbox[2] = x2 * input_w;
        det->bbox[3] = y2 * input_h;
        det->class_confidence = score;
    }
}
95+
96+
// Runs the decode kernel once per feature-map level (strides 8, 16, 32).
//
// inputs    : device pointers, one per level, in stride order 8, 16, 32.
// output    : device buffer; output[0] is the detection counter, Detection
//             records follow it.
// stream    : CUDA stream on which all work is enqueued.
// batchSize : currently unused. NOTE(review): only the first sample of a
//             batch is decoded; confirm callers never pass batchSize > 1.
//
// Errors from the CUDA runtime are reported on stderr and abort the remaining
// launches (the function has no return value to propagate them).
void DecodePlugin::forwardGpu(const float *const * inputs, float * output, cudaStream_t stream, int batchSize)
{
    (void)batchSize;

    // The kernels append via atomicAdd on output[0], so the counter has to
    // start at zero on every invocation.
    cudaError_t err = cudaMemsetAsync(output, 0, sizeof(float), stream);
    if (err != cudaSuccess)
    {
        fprintf(stderr, "DecodePlugin: counter reset failed: %s\n", cudaGetErrorString(err));
        return;
    }

    int base_step = 8;
    int base_anchor = 16;
    for (int i = 0; i < 3; ++i)
    {
        // Same cell count the kernel derives from (input_h / step, input_w / step).
        const int num_elem = (input_h_ / base_step) * (input_w_ / base_step);

        // Never launch more threads per block than there are cells, and
        // guard against a non-positive configured thread count.
        int thread_count = (num_elem < thread_count_) ? num_elem : thread_count_;
        if (thread_count < 1) thread_count = 1;
        const int block_count = (num_elem + thread_count - 1) / thread_count;

        CalDetection<<<block_count, thread_count, 0, stream>>>
            (inputs[i], output, num_elem, input_h_, input_w_, base_step, base_anchor);

        err = cudaGetLastError();
        if (err != cudaSuccess)
        {
            fprintf(stderr, "DecodePlugin: CalDetection launch failed (stride %d): %s\n",
                    base_step, cudaGetErrorString(err));
            return;
        }

        base_step *= 2;    // next level: stride doubles
        base_anchor *= 4;  // each level has two anchors (a, 2a); next level starts at 4a
    }
}
113+
114+
115+
// Execution entry point called with the layer's input/output bindings.
// Reinterprets the untyped bindings as float buffers and hands them, together
// with the stream and batch size, to forwardGpu(). `workspace` is unused.
// Always returns 0.
int DecodePlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
{
    const float *const *levelInputs = reinterpret_cast<const float *const *>(inputs);
    float *detections = static_cast<float *>(outputs[0]);

    forwardGpu(levelInputs, detections, stream, batchSize);
    return 0;
}
124+
125+
}

retinaface/decode.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#ifndef _DECODE_CU_H
2+
#define _DECODE_CU_H
3+
4+
#include "NvInfer.h"
5+
6+
namespace decodeplugin
{
    // One decoded face candidate, laid out as 15 consecutive floats so that
    // records can be packed into a plain float buffer.
    struct alignas(float) Detection{
        // Box corners in input-image pixels as written by the decode kernel:
        // [0] left, [1] top, [2] right, [3] bottom (not x/y/width/height).
        float bbox[4];
        // Softmax probability of the second class score plane.
        float class_confidence;
        // Storage for landmark values; the decode kernel in decode.cu does
        // not fill this field.
        float landmark[10];
    };
}
15+
16+
17+
namespace nvinfer1
18+
{
19+
class DecodePlugin: public IPluginExt
20+
{
21+
public:
22+
explicit DecodePlugin(const int cudaThread = 256);
23+
DecodePlugin(const void* data, size_t length);
24+
25+
~DecodePlugin();
26+
27+
int getNbOutputs() const override
28+
{
29+
return 1;
30+
}
31+
32+
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
33+
34+
bool supportsFormat(DataType type, PluginFormat format) const override {
35+
return type == DataType::kFLOAT && format == PluginFormat::kNCHW;
36+
}
37+
38+
void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {};
39+
40+
int initialize() override;
41+
42+
virtual void terminate() override {};
43+
44+
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
45+
46+
virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
47+
48+
virtual size_t getSerializationSize() override;
49+
50+
virtual void serialize(void* buffer) override;
51+
52+
void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
53+
54+
private:
55+
const int input_h_ = 384;
56+
const int input_w_ = 640;
57+
int thread_count_ = 256;
58+
};
59+
};
60+
61+
#endif

retinaface/plugin_factory.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "plugin_factory.h"
2+
#include "NvInferPlugin.h"
3+
#include "decode.h"
4+
#include "common.h"
5+
6+
using namespace nvinfer1;
7+
using nvinfer1::PluginFactory;
8+
9+
// Re-creates a plugin layer during engine deserialization, choosing the
// implementation by substring match on the layer name:
//   contains "leaky"  -> PReLU plugin from the TensorRT plugin library
//   contains "decode" -> DecodePlugin
// "leaky" is tested first. Returns nullptr when neither substring matches.
IPlugin* PluginFactory::createPlugin(const char* layerName, const void* serialData, size_t serialLength) {
    if (strstr(layerName, "leaky") != NULL) {
        return plugin::createPReLUPlugin(serialData, serialLength);
    }
    if (strstr(layerName, "decode") != NULL) {
        return new DecodePlugin(serialData, serialLength);
    }
    return nullptr;
}

retinaface/plugin_factory.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef MY_PLUGIN_FACTORY_H
2+
#define MY_PLUGIN_FACTORY_H
3+
#include <NvInfer.h>
4+
5+
namespace nvinfer1 {
6+
class PluginFactory : public IPluginFactory {
7+
public:
8+
IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override;
9+
};
10+
11+
}
12+
#endif

0 commit comments

Comments
 (0)