@@ -34,22 +34,22 @@ namespace nvinfer1
3434 {
3535 // output the result to channel
3636 int totalCount = 0 ;
37- totalCount += input_h_ / 8 * input_w_ / 8 * 2 * sizeof (decodeplugin::Detection) / sizeof (float );
38- totalCount += input_h_ / 16 * input_w_ / 16 * 2 * sizeof (decodeplugin::Detection) / sizeof (float );
39- totalCount += input_h_ / 32 * input_w_ / 32 * 2 * sizeof (decodeplugin::Detection) / sizeof (float );
37+ totalCount += decodeplugin::INPUT_H / 8 * decodeplugin::INPUT_W / 8 * 2 * sizeof (decodeplugin::Detection) / sizeof (float );
38+ totalCount += decodeplugin::INPUT_H / 16 * decodeplugin::INPUT_W / 16 * 2 * sizeof (decodeplugin::Detection) / sizeof (float );
39+ totalCount += decodeplugin::INPUT_H / 32 * decodeplugin::INPUT_W / 32 * 2 * sizeof (decodeplugin::Detection) / sizeof (float );
4040
4141 return Dims3 (totalCount + 1 , 1 , 1 );
4242 }
4343
4444 __device__ float Logist (float data){ return 1 ./(1 . + exp (-data)); };
4545
46- __global__ void CalDetection (const float *input, float *output, int num_elem, int input_h, int input_w, int step, int anchor) {
46+ __global__ void CalDetection (const float *input, float *output, int num_elem, int step, int anchor) {
4747
4848 int idx = threadIdx .x + blockDim .x * blockIdx .x ;
4949 if (idx >= num_elem) return ;
5050
51- int h = input_h / step;
52- int w = input_w / step;
51+ int h = decodeplugin::INPUT_H / step;
52+ int w = decodeplugin::INPUT_W / step;
5353 int y = idx / w;
5454 int x = idx % w;
5555 const float *bbox_reg = &input[0 ];
@@ -60,7 +60,7 @@ namespace nvinfer1
6060 float conf1 = cls_reg[idx + k * num_elem * 2 ];
6161 float conf2 = cls_reg[idx + k * num_elem * 2 + num_elem];
6262 conf2 = exp (conf2) / (exp (conf1) + exp (conf2));
63- if (conf2 <= 0.002 ) continue ;
63+ if (conf2 <= 0.02 ) continue ;
6464
6565 float *res_count = output;
6666 int count = (int )atomicAdd (res_count, 1 );
@@ -70,10 +70,8 @@ namespace nvinfer1
7070 float prior[4 ];
7171 prior[0 ] = ((float )x + 0.5 ) / w;
7272 prior[1 ] = ((float )y + 0.5 ) / h;
73- prior[2 ] = (float )anchor / input_w;
74- prior[3 ] = (float )anchor / input_h;
75- printf (" prior0, %f\n " , prior[0 ]);
76- printf (" bbox0, %f\n " , bbox_reg[idx + k * num_elem * 4 ]);
73+ prior[2 ] = (float )anchor * (k + 1 ) / decodeplugin::INPUT_W;
74+ prior[3 ] = (float )anchor * (k + 1 ) / decodeplugin::INPUT_H;
7775
7876 // Location
7977 det->bbox [0 ] = prior[0 ] + bbox_reg[idx + k * num_elem * 4 ] * 0.1 * prior[2 ];
@@ -84,12 +82,17 @@ namespace nvinfer1
8482 det->bbox [1 ] -= det->bbox [3 ] / 2 ;
8583 det->bbox [2 ] += det->bbox [0 ];
8684 det->bbox [3 ] += det->bbox [1 ];
87- det->bbox [0 ] *= input_w ;
88- det->bbox [1 ] *= input_h ;
89- det->bbox [2 ] *= input_w ;
90- det->bbox [3 ] *= input_h ;
85+ det->bbox [0 ] *= decodeplugin::INPUT_W ;
86+ det->bbox [1 ] *= decodeplugin::INPUT_H ;
87+ det->bbox [2 ] *= decodeplugin::INPUT_W ;
88+ det->bbox [3 ] *= decodeplugin::INPUT_H ;
9189 det->class_confidence = conf2;
92- anchor *= 2 ;
90+ for (int i = 0 ; i < 10 ; i += 2 ) {
91+ det->landmark [i] = prior[0 ] + lmk_reg[idx + k * num_elem * 10 + num_elem * i] * 0.1 * prior[2 ];
92+ det->landmark [i+1 ] = prior[1 ] + lmk_reg[idx + k * num_elem * 10 + num_elem * (i + 1 )] * 0.1 * prior[3 ];
93+ det->landmark [i] *= decodeplugin::INPUT_W;
94+ det->landmark [i+1 ] *= decodeplugin::INPUT_H;
95+ }
9396 }
9497 }
9598
@@ -101,17 +104,15 @@ namespace nvinfer1
101104 int thread_count;
102105 for (unsigned int i = 0 ; i < 3 ; ++i)
103106 {
104- num_elem = input_h_ / base_step * input_w_ / base_step;
107+ num_elem = decodeplugin::INPUT_H / base_step * decodeplugin::INPUT_W / base_step;
105108 thread_count = (num_elem < thread_count_) ? num_elem : thread_count_;
106109 CalDetection<<< (num_elem + thread_count - 1 ) / thread_count, thread_count>>>
107- (inputs[i], output, num_elem, input_h_, input_w_, base_step, base_anchor);
110+ (inputs[i], output, num_elem, base_step, base_anchor);
108111 base_step *= 2 ;
109112 base_anchor *= 4 ;
110113 }
111-
112114 }
113115
114-
115116 int DecodePlugin::enqueue (int batchSize, const void *const * inputs, void ** outputs, void * workspace, cudaStream_t stream)
116117 {
117118 // assert(batchSize == 1);
0 commit comments