|
16 | 16 | namespace nvinfer1 { |
17 | 17 |
|
18 | 18 | __global__ void batched_nms_kernel( |
19 | | - const int num_per_thread, const float threshold, const int num_detections, |
| 19 | + const float threshold, const int num_detections, |
20 | 20 | const int *indices, float *scores, const float *classes, const float4 *boxes) { |
21 | 21 |
|
22 | 22 | // Go through detections by descending score |
23 | 23 | for (int m = 0; m < num_detections; m++) { |
24 | | - for (int n = 0; n < num_per_thread; n++) { |
25 | | - int i = threadIdx.x * num_per_thread + n; |
26 | | - if (i < num_detections && m < i && scores[m] > 0.0f) { |
27 | | - int idx = indices[i]; |
28 | | - int max_idx = indices[m]; |
29 | | - int icls = classes[idx]; |
30 | | - int mcls = classes[max_idx]; |
31 | | - if (mcls == icls) { |
32 | | - float4 ibox = boxes[idx]; |
33 | | - float4 mbox = boxes[max_idx]; |
34 | | - float x1 = max(ibox.x, mbox.x); |
35 | | - float y1 = max(ibox.y, mbox.y); |
36 | | - float x2 = min(ibox.z, mbox.z); |
37 | | - float y2 = min(ibox.w, mbox.w); |
38 | | - float w = max(0.0f, x2 - x1); |
39 | | - float h = max(0.0f, y2 - y1); |
40 | | - float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y); |
41 | | - float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y); |
42 | | - float inter = w * h; |
43 | | - float overlap = inter / (iarea + marea - inter); |
44 | | - if (overlap > threshold) { |
45 | | - scores[i] = 0.0f; |
46 | | - } |
| 24 | + int i = blockIdx.x * blockDim.x + threadIdx.x; |
| 25 | + if (i < num_detections && m < i && scores[m] > 0.0f) { |
| 26 | + int idx = indices[i]; |
| 27 | + int max_idx = indices[m]; |
| 28 | + int icls = classes[idx]; |
| 29 | + int mcls = classes[max_idx]; |
| 30 | + if (mcls == icls) { |
| 31 | + float4 ibox = boxes[idx]; |
| 32 | + float4 mbox = boxes[max_idx]; |
| 33 | + float x1 = max(ibox.x, mbox.x); |
| 34 | + float y1 = max(ibox.y, mbox.y); |
| 35 | + float x2 = min(ibox.z, mbox.z); |
| 36 | + float y2 = min(ibox.w, mbox.w); |
| 37 | + float w = max(0.0f, x2 - x1); |
| 38 | + float h = max(0.0f, y2 - y1); |
| 39 | + float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y); |
| 40 | + float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y); |
| 41 | + float inter = w * h; |
| 42 | + float overlap = inter / (iarea + marea - inter); |
| 43 | + if (overlap > threshold) { |
| 44 | + scores[i] = 0.0f; |
47 | 45 | } |
48 | 46 | } |
49 | 47 | } |
@@ -104,7 +102,7 @@ int batchedNms(int batch_size, |
104 | 102 | // TODO: different device has differnet max threads |
105 | 103 | const int max_threads = 1024; |
106 | 104 | int num_per_thread = ceil(static_cast<float>(num_detections) / max_threads); |
107 | | - batched_nms_kernel << <1, max_threads, 0, stream >> > (num_per_thread, nms_thresh, num_detections, |
| 105 | + batched_nms_kernel << <num_per_thread, max_threads, 0, stream >> > (nms_thresh, num_detections, |
108 | 106 | indices_sorted, scores_sorted, in_classes, in_boxes); |
109 | 107 |
|
110 | 108 | // Re-sort with updated scores |
|
0 commit comments