@@ -34,22 +34,22 @@ namespace nvinfer1
3434    {
3535        // output the result to channel
3636        int  totalCount = 0 ;
37-         totalCount += input_h_  / 8  * input_w_  / 8  * 2  * sizeof (decodeplugin::Detection) / sizeof (float );
38-         totalCount += input_h_  / 16  * input_w_  / 16  * 2  * sizeof (decodeplugin::Detection) / sizeof (float );
39-         totalCount += input_h_  / 32  * input_w_  / 32  * 2  * sizeof (decodeplugin::Detection) / sizeof (float );
37+         totalCount += decodeplugin::INPUT_H  / 8  * decodeplugin::INPUT_W  / 8  * 2  * sizeof (decodeplugin::Detection) / sizeof (float );
38+         totalCount += decodeplugin::INPUT_H  / 16  * decodeplugin::INPUT_W  / 16  * 2  * sizeof (decodeplugin::Detection) / sizeof (float );
39+         totalCount += decodeplugin::INPUT_H  / 32  * decodeplugin::INPUT_W  / 32  * 2  * sizeof (decodeplugin::Detection) / sizeof (float );
4040
4141        return  Dims3 (totalCount + 1 , 1 , 1 );
4242    }
4343
4444    __device__  float  Logist (float  data){ return  1 ./(1 . + exp (-data)); };
4545
46-     __global__  void  CalDetection (const  float  *input, float  *output, int  num_elem, int  input_h,  int  input_w,  int   step, int  anchor) {
46+     __global__  void  CalDetection (const  float  *input, float  *output, int  num_elem, int  step, int  anchor) {
4747
4848        int  idx = threadIdx .x  + blockDim .x  * blockIdx .x ;
4949        if  (idx >= num_elem) return ;
5050
51-         int  h = input_h  / step;
52-         int  w = input_w  / step;
51+         int  h = decodeplugin::INPUT_H  / step;
52+         int  w = decodeplugin::INPUT_W  / step;
5353        int  y = idx / w;
5454        int  x = idx % w;
5555        const  float  *bbox_reg = &input[0 ];
@@ -60,7 +60,7 @@ namespace nvinfer1
6060            float  conf1 = cls_reg[idx + k * num_elem * 2 ];
6161            float  conf2 = cls_reg[idx + k * num_elem * 2  + num_elem];
6262            conf2 = exp (conf2) / (exp (conf1) + exp (conf2));
63-             if  (conf2 <= 0.002  ) continue ;
63+             if  (conf2 <= 0.02  ) continue ;
6464
6565            float  *res_count = output;
6666            int  count = (int )atomicAdd (res_count, 1 );
@@ -70,10 +70,8 @@ namespace nvinfer1
7070            float  prior[4 ];
7171            prior[0 ] = ((float )x + 0.5 ) / w;
7272            prior[1 ] = ((float )y + 0.5 ) / h;
73-             prior[2 ] = (float )anchor / input_w;
74-             prior[3 ] = (float )anchor / input_h;
75-             printf (" prior0, %f\n "  , prior[0 ]);
76-             printf (" bbox0, %f\n "  , bbox_reg[idx + k * num_elem * 4 ]);
73+             prior[2 ] = (float )anchor * (k + 1 ) / decodeplugin::INPUT_W;
74+             prior[3 ] = (float )anchor * (k + 1 ) / decodeplugin::INPUT_H;
7775
7876            // Location
7977            det->bbox [0 ] = prior[0 ] + bbox_reg[idx + k * num_elem * 4 ] * 0.1  * prior[2 ];
@@ -84,12 +82,17 @@ namespace nvinfer1
8482            det->bbox [1 ] -= det->bbox [3 ] / 2 ;
8583            det->bbox [2 ] += det->bbox [0 ];
8684            det->bbox [3 ] += det->bbox [1 ];
87-             det->bbox [0 ] *= input_w ;
88-             det->bbox [1 ] *= input_h ;
89-             det->bbox [2 ] *= input_w ;
90-             det->bbox [3 ] *= input_h ;
85+             det->bbox [0 ] *= decodeplugin::INPUT_W ;
86+             det->bbox [1 ] *= decodeplugin::INPUT_H ;
87+             det->bbox [2 ] *= decodeplugin::INPUT_W ;
88+             det->bbox [3 ] *= decodeplugin::INPUT_H ;
9189            det->class_confidence  = conf2;
92-             anchor *= 2 ;
90+             for  (int  i = 0 ; i < 10 ; i += 2 ) {
91+                 det->landmark [i] = prior[0 ] + lmk_reg[idx + k * num_elem * 10  + num_elem * i] * 0.1  * prior[2 ];
92+                 det->landmark [i+1 ] = prior[1 ] + lmk_reg[idx + k * num_elem * 10  + num_elem * (i + 1 )] * 0.1  * prior[3 ];
93+                 det->landmark [i] *= decodeplugin::INPUT_W;
94+                 det->landmark [i+1 ] *= decodeplugin::INPUT_H;
95+             }
9396        }
9497    }
9598
@@ -101,17 +104,15 @@ namespace nvinfer1
101104        int  thread_count;
102105        for  (unsigned  int  i = 0 ; i < 3 ; ++i)
103106        {
104-             num_elem = input_h_  / base_step * input_w_  / base_step;
107+             num_elem = decodeplugin::INPUT_H  / base_step * decodeplugin::INPUT_W  / base_step;
105108            thread_count = (num_elem < thread_count_) ? num_elem : thread_count_;
106109            CalDetection<<< (num_elem + thread_count - 1 ) / thread_count, thread_count>>> 
107-                 (inputs[i], output, num_elem, input_h_, input_w_,  base_step, base_anchor);
110+                 (inputs[i], output, num_elem, base_step, base_anchor);
108111            base_step *= 2 ;
109112            base_anchor *= 4 ;
110113        }
111- 
112114    }
113115
114- 
115116    int  DecodePlugin::enqueue (int  batchSize, const  void *const  * inputs, void ** outputs, void * workspace, cudaStream_t stream)
116117    {
117118        // assert(batchSize == 1);
0 commit comments