|  | 
|  | 1 | +#include "decode.h" | 
|  | 2 | +#include "stdio.h" | 
|  | 3 | + | 
|  | 4 | +namespace nvinfer1 | 
|  | 5 | +{ | 
|  | 6 | +    DecodePlugin::DecodePlugin(const int cudaThread):thread_count_(cudaThread) | 
|  | 7 | +    { | 
|  | 8 | +    } | 
|  | 9 | +     | 
|  | 10 | +    DecodePlugin::~DecodePlugin() | 
|  | 11 | +    { | 
|  | 12 | +    } | 
|  | 13 | +     | 
|  | 14 | +    // create the plugin at runtime from a byte stream | 
|  | 15 | +    DecodePlugin::DecodePlugin(const void* data, size_t length) | 
|  | 16 | +    { | 
|  | 17 | +    } | 
|  | 18 | + | 
|  | 19 | +    void DecodePlugin::serialize(void* buffer) | 
|  | 20 | +    { | 
|  | 21 | +    } | 
|  | 22 | +     | 
|  | 23 | +    size_t DecodePlugin::getSerializationSize() | 
|  | 24 | +    {   | 
|  | 25 | +        return 0; | 
|  | 26 | +    } | 
|  | 27 | + | 
|  | 28 | +    int DecodePlugin::initialize() | 
|  | 29 | +    {  | 
|  | 30 | +        return 0; | 
|  | 31 | +    } | 
|  | 32 | +     | 
|  | 33 | +    Dims DecodePlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) | 
|  | 34 | +    { | 
|  | 35 | +        //output the result to channel | 
|  | 36 | +        int totalCount = 0; | 
|  | 37 | +        totalCount += input_h_ / 8 * input_w_ / 8 * 2 * sizeof(decodeplugin::Detection) / sizeof(float); | 
|  | 38 | +        totalCount += input_h_ / 16 * input_w_ / 16 * 2 * sizeof(decodeplugin::Detection) / sizeof(float); | 
|  | 39 | +        totalCount += input_h_ / 32 * input_w_ / 32 * 2 * sizeof(decodeplugin::Detection) / sizeof(float); | 
|  | 40 | + | 
|  | 41 | +        return Dims3(totalCount + 1, 1, 1); | 
|  | 42 | +    } | 
|  | 43 | + | 
|  | 44 | +    __device__ float Logist(float data){ return 1./(1. + exp(-data)); }; | 
|  | 45 | + | 
|  | 46 | +    __global__ void CalDetection(const float *input, float *output, int num_elem, int input_h, int input_w, int step, int anchor) { | 
|  | 47 | +  | 
|  | 48 | +        int idx = threadIdx.x + blockDim.x * blockIdx.x; | 
|  | 49 | +        if (idx >= num_elem) return; | 
|  | 50 | + | 
|  | 51 | +        int h = input_h / step; | 
|  | 52 | +        int w = input_w / step; | 
|  | 53 | +        int y = idx / w; | 
|  | 54 | +        int x = idx % w; | 
|  | 55 | +        const float *bbox_reg = &input[0]; | 
|  | 56 | +        const float *cls_reg = &input[2 * 4 * num_elem]; | 
|  | 57 | +        const float *lmk_reg = &input[2 * 4 * num_elem + 2 * 2 * num_elem]; | 
|  | 58 | + | 
|  | 59 | +        for (int k = 0; k < 2; ++k) { | 
|  | 60 | +            float conf1 = cls_reg[idx + k * num_elem * 2]; | 
|  | 61 | +            float conf2 = cls_reg[idx + k * num_elem * 2 + num_elem]; | 
|  | 62 | +            conf2 = exp(conf2) / (exp(conf1) + exp(conf2)); | 
|  | 63 | +            if (conf2 <= 0.002) continue; | 
|  | 64 | + | 
|  | 65 | +            float *res_count = output; | 
|  | 66 | +            int count = (int)atomicAdd(res_count, 1); | 
|  | 67 | +            char* data = (char *)res_count + sizeof(float) + count * sizeof(decodeplugin::Detection); | 
|  | 68 | +            decodeplugin::Detection* det = (decodeplugin::Detection*)(data); | 
|  | 69 | + | 
|  | 70 | +            float prior[4]; | 
|  | 71 | +            prior[0] = ((float)x + 0.5) / w; | 
|  | 72 | +            prior[1] = ((float)y + 0.5) / h; | 
|  | 73 | +            prior[2] = (float)anchor / input_w; | 
|  | 74 | +            prior[3] = (float)anchor / input_h; | 
|  | 75 | +            printf("prior0, %f\n", prior[0]); | 
|  | 76 | +            printf("bbox0, %f\n", bbox_reg[idx + k * num_elem * 4]); | 
|  | 77 | + | 
|  | 78 | +            //Location | 
|  | 79 | +            det->bbox[0] = prior[0] + bbox_reg[idx + k * num_elem * 4] * 0.1 * prior[2]; | 
|  | 80 | +            det->bbox[1] = prior[1] + bbox_reg[idx + k * num_elem * 4 + num_elem] * 0.1 * prior[3]; | 
|  | 81 | +            det->bbox[2] = prior[2] * exp(bbox_reg[idx + k * num_elem * 4 + num_elem * 2] * 0.2); | 
|  | 82 | +            det->bbox[3] = prior[3] * exp(bbox_reg[idx + k * num_elem * 4 + num_elem * 3] * 0.2); | 
|  | 83 | +            det->bbox[0] -= det->bbox[2] / 2; | 
|  | 84 | +            det->bbox[1] -= det->bbox[3] / 2; | 
|  | 85 | +            det->bbox[2] += det->bbox[0]; | 
|  | 86 | +            det->bbox[3] += det->bbox[1]; | 
|  | 87 | +            det->bbox[0] *= input_w; | 
|  | 88 | +            det->bbox[1] *= input_h; | 
|  | 89 | +            det->bbox[2] *= input_w; | 
|  | 90 | +            det->bbox[3] *= input_h; | 
|  | 91 | +            det->class_confidence = conf2; | 
|  | 92 | +            anchor *= 2; | 
|  | 93 | +        } | 
|  | 94 | +    } | 
|  | 95 | +    | 
|  | 96 | +    void DecodePlugin::forwardGpu(const float *const * inputs, float * output, cudaStream_t stream, int batchSize)  | 
|  | 97 | +    { | 
|  | 98 | +        int num_elem = 0; | 
|  | 99 | +        int base_step = 8; | 
|  | 100 | +        int base_anchor = 16; | 
|  | 101 | +        int thread_count; | 
|  | 102 | +        for (unsigned int i = 0; i < 3; ++i) | 
|  | 103 | +        { | 
|  | 104 | +            num_elem = input_h_ / base_step * input_w_ / base_step; | 
|  | 105 | +            thread_count = (num_elem < thread_count_) ? num_elem : thread_count_; | 
|  | 106 | +            CalDetection<<< (num_elem + thread_count - 1) / thread_count, thread_count>>> | 
|  | 107 | +                (inputs[i], output, num_elem, input_h_, input_w_, base_step, base_anchor); | 
|  | 108 | +            base_step *= 2; | 
|  | 109 | +            base_anchor *= 4; | 
|  | 110 | +        } | 
|  | 111 | + | 
|  | 112 | +    } | 
|  | 113 | + | 
|  | 114 | + | 
|  | 115 | +    int DecodePlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) | 
|  | 116 | +    { | 
|  | 117 | +        //assert(batchSize == 1); | 
|  | 118 | +        //GPU | 
|  | 119 | +        //CUDA_CHECK(cudaStreamSynchronize(stream)); | 
|  | 120 | +        forwardGpu((const float *const *)inputs,(float *)outputs[0],stream,batchSize); | 
|  | 121 | + | 
|  | 122 | +        return 0; | 
|  | 123 | +    }; | 
|  | 124 | + | 
|  | 125 | +} | 
0 commit comments