|
17 | 17 | #define DEVICE 0 // GPU id |
18 | 18 |
|
19 | 19 | using namespace nvinfer1; |
20 | | -using namespace Yolo; |
21 | 20 |
|
22 | 21 | // stuff we know about the network and the input/output blobs |
23 | | -static const int INPUT_H = 256; |
24 | | -static const int INPUT_W = 416; |
| 22 | +static const int INPUT_H = Yolo::INPUT_H; |
| 23 | +static const int INPUT_W = Yolo::INPUT_W; |
25 | 24 | static const int OUTPUT_SIZE = 1000 * 7 + 1; // we assume the yololayer outputs no more than 1000 boxes that conf >= 0.1 |
26 | 25 | const char* INPUT_BLOB_NAME = "data"; |
27 | 26 | const char* OUTPUT_BLOB_NAME = "prob"; |
@@ -90,17 +89,17 @@ float iou(float lbox[4], float rbox[4]) { |
90 | 89 | return interBoxS/(lbox[2]*lbox[3] + rbox[2]*rbox[3] -interBoxS); |
91 | 90 | } |
92 | 91 |
|
93 | | -bool cmp(Detection& a, Detection& b) { |
| 92 | +bool cmp(Yolo::Detection& a, Yolo::Detection& b) { |
94 | 93 | return a.det_confidence > b.det_confidence; |
95 | 94 | } |
96 | 95 |
|
97 | | -void nms(std::vector<Detection>& res, float *output, float nms_thresh = 0.4) { |
98 | | - std::map<float, std::vector<Detection>> m; |
| 96 | +void nms(std::vector<Yolo::Detection>& res, float *output, float nms_thresh = 0.4) { |
| 97 | + std::map<float, std::vector<Yolo::Detection>> m; |
99 | 98 | for (int i = 0; i < output[0] && i < 1000; i++) { |
100 | 99 | if (output[1 + 7 * i + 4] <= 0.5) continue; |
101 | | - Detection det; |
| 100 | + Yolo::Detection det; |
102 | 101 | memcpy(&det, &output[1 + 7 * i], 7 * sizeof(float)); |
103 | | - if (m.count(det.class_id) == 0) m.emplace(det.class_id, std::vector<Detection>()); |
| 102 | + if (m.count(det.class_id) == 0) m.emplace(det.class_id, std::vector<Yolo::Detection>()); |
104 | 103 | m[det.class_id].push_back(det); |
105 | 104 | } |
106 | 105 | for (auto it = m.begin(); it != m.end(); it++) { |
@@ -535,7 +534,7 @@ int main(int argc, char** argv) { |
535 | 534 | // Run inference |
536 | 535 | auto start = std::chrono::system_clock::now(); |
537 | 536 | doInference(*context, data, prob, 1); |
538 | | - std::vector<Detection> res; |
| 537 | + std::vector<Yolo::Detection> res; |
539 | 538 | nms(res, prob); |
540 | 539 | auto end = std::chrono::system_clock::now(); |
541 | 540 | std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl; |
|
0 commit comments