1818#define DEVICE 0 // GPU id
// Default value of the nms_thresh parameter of nms().
1919#define NMS_THRESH 0.4
// nms() skips any detection whose float at offset 4 is <= this value.
2020#define BBOX_CONF_THRESH 0.5
// Batch size passed to APIToModel() and doInference(); it also scales the
// data and prob buffers declared in main().
21+ #define BATCH_SIZE 1
2122
2223using namespace nvinfer1 ;
2324
2425// stuff we know about the network and the input/output blobs
// Input height/width, taken from the Yolo definitions; main() sizes and
// fills the input buffer as 3 planes of INPUT_H * INPUT_W floats per image.
2526static const int INPUT_H = Yolo::INPUT_H;
2627static const int INPUT_W = Yolo::INPUT_W;
27- static const int OUTPUT_SIZE = 1000 * 7 + 1 ; // we assume the yololayer outputs no more than 1000 boxes that conf >= 0.1
// Number of floats occupied by one Yolo::Detection record.
28+ static const int DETECTION_SIZE = sizeof (Yolo::Detection) / sizeof (float );
// Per-image output length in floats: one leading float (read by nms() as the
// detection count) followed by up to MAX_OUTPUT_BBOX_COUNT detection records.
29+ static const int OUTPUT_SIZE = Yolo::MAX_OUTPUT_BBOX_COUNT * DETECTION_SIZE + 1 ; // we assume the yololayer outputs no more than MAX_OUTPUT_BBOX_COUNT boxes with conf >= 0.1
// Names of the network's input and output blobs.
// NOTE(review): their use is outside this view — presumably binding lookup; confirm.
2830const char * INPUT_BLOB_NAME = " data" ;
2931const char * OUTPUT_BLOB_NAME = " prob" ;
// Logger instance handed to createInferRuntime() in main().
3032static Logger gLogger ;
@@ -98,10 +100,10 @@ bool cmp(Yolo::Detection& a, Yolo::Detection& b) {
98100
99101void nms (std::vector<Yolo::Detection>& res, float *output, float nms_thresh = NMS_THRESH) {
100102 std::map<float , std::vector<Yolo::Detection>> m;
101- for (int i = 0 ; i < output[0 ] && i < 1000 ; i++) {
102- if (output[1 + 7 * i + 4 ] <= BBOX_CONF_THRESH) continue ;
103+ for (int i = 0 ; i < output[0 ] && i < Yolo::MAX_OUTPUT_BBOX_COUNT ; i++) {
104+ if (output[1 + DETECTION_SIZE * i + 4 ] <= BBOX_CONF_THRESH) continue ;
103105 Yolo::Detection det;
104- memcpy (&det, &output[1 + 7 * i], 7 * sizeof (float ));
106+ memcpy (&det, &output[1 + DETECTION_SIZE * i], DETECTION_SIZE * sizeof (float ));
105107 if (m.count (det.class_id ) == 0 ) m.emplace (det.class_id , std::vector<Yolo::Detection>());
106108 m[det.class_id ].push_back (det);
107109 }
@@ -582,7 +584,7 @@ int main(int argc, char** argv) {
582584
583585 if (argc == 2 && std::string (argv[1 ]) == " -s" ) {
584586 IHostMemory* modelStream{nullptr };
585- APIToModel (1 , &modelStream);
587+ APIToModel (BATCH_SIZE , &modelStream);
586588 assert (modelStream != nullptr );
587589 std::ofstream p (" yolov4.engine" );
588590 if (!p) {
@@ -617,10 +619,10 @@ int main(int argc, char** argv) {
617619 }
618620
619621 // prepare input data ---------------------------
620- float data[3 * INPUT_H * INPUT_W];
622+ static float data[BATCH_SIZE * 3 * INPUT_H * INPUT_W];
621623 // for (int i = 0; i < 3 * INPUT_H * INPUT_W; i++)
622624 // data[i] = 1.0;
623- static float prob[OUTPUT_SIZE];
625+ static float prob[BATCH_SIZE * OUTPUT_SIZE];
624626 PluginFactory pf;
625627 IRuntime* runtime = createInferRuntime (gLogger );
626628 assert (runtime != nullptr );
@@ -630,37 +632,47 @@ int main(int argc, char** argv) {
630632 assert (context != nullptr );
631633
632634 int fcount = 0 ;
633- for (auto f: file_names) {
635+ for (int f = 0 ; f < file_names. size (); f++ ) {
634636 fcount++;
635- std::cout << fcount << " " << f << std::endl;
636- cv::Mat img = cv::imread (std::string (argv[2 ]) + " /" + f);
637- if (img.empty ()) continue ;
638- cv::Mat pr_img = preprocess_img (img);
639- for (int i = 0 ; i < INPUT_H * INPUT_W; i++) {
640- data[i] = pr_img.at <cv::Vec3b>(i)[2 ] / 255.0 ;
641- data[i + INPUT_H * INPUT_W] = pr_img.at <cv::Vec3b>(i)[1 ] / 255.0 ;
642- data[i + 2 * INPUT_H * INPUT_W] = pr_img.at <cv::Vec3b>(i)[0 ] / 255.0 ;
637+ if (fcount < BATCH_SIZE && f + 1 != file_names.size ()) continue ;
638+ for (int b = 0 ; b < fcount; b++) {
639+ cv::Mat img = cv::imread (std::string (argv[2 ]) + " /" + file_names[f - BATCH_SIZE + 1 + b]);
640+ if (img.empty ()) continue ;
641+ cv::Mat pr_img = preprocess_img (img);
642+ for (int i = 0 ; i < INPUT_H * INPUT_W; i++) {
643+ data[b * 3 * INPUT_H * INPUT_W + i] = pr_img.at <cv::Vec3b>(i)[2 ] / 255.0 ;
644+ data[b * 3 * INPUT_H * INPUT_W + i + INPUT_H * INPUT_W] = pr_img.at <cv::Vec3b>(i)[1 ] / 255.0 ;
645+ data[b * 3 * INPUT_H * INPUT_W + i + 2 * INPUT_H * INPUT_W] = pr_img.at <cv::Vec3b>(i)[0 ] / 255.0 ;
646+ }
643647 }
644648
645649 // Run inference
646650 auto start = std::chrono::system_clock::now ();
647- doInference (*context, data, prob, 1 );
648- std::vector<Yolo::Detection> res;
649- nms (res, prob);
651+ doInference (*context, data, prob, BATCH_SIZE);
652+ std::vector<std::vector<Yolo::Detection>> batch_res (fcount);
653+ for (int b = 0 ; b < fcount; b++) {
654+ auto & res = batch_res[b];
655+ nms (res, &prob[b * OUTPUT_SIZE]);
656+ }
650657 auto end = std::chrono::system_clock::now ();
651658 std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count () << " ms" << std::endl;
652- std::cout << res.size () << std::endl;
653- for (size_t j = 0 ; j < res.size (); j++) {
654- float *p = (float *)&res[j];
655- for (size_t k = 0 ; k < 7 ; k++) {
656- std::cout << p[k] << " , " ;
659+ for (int b = 0 ; b < fcount; b++) {
660+ auto & res = batch_res[b];
661+ // std::cout << res.size() << std::endl;
662+ cv::Mat img = cv::imread (std::string (argv[2 ]) + " /" + file_names[f - BATCH_SIZE + 1 + b]);
663+ for (size_t j = 0 ; j < res.size (); j++) {
664+ float *p = (float *)&res[j];
665+ for (size_t k = 0 ; k < 7 ; k++) {
666+ // std::cout << p[k] << ", ";
667+ }
668+ // std::cout << std::endl;
669+ cv::Rect r = get_rect (img, res[j].bbox );
670+ cv::rectangle (img, r, cv::Scalar (0x27 , 0xC1 , 0x36 ), 2 );
671+ cv::putText (img, std::to_string ((int )res[j].class_id ), cv::Point (r.x , r.y - 1 ), cv::FONT_HERSHEY_PLAIN, 1.2 , cv::Scalar (0xFF , 0xFF , 0xFF ), 2 );
657672 }
658- std::cout << std::endl;
659- cv::Rect r = get_rect (img, res[j].bbox );
660- cv::rectangle (img, r, cv::Scalar (0x27 , 0xC1 , 0x36 ), 2 );
661- cv::putText (img, std::to_string ((int )res[j].class_id ), cv::Point (r.x , r.y - 1 ), cv::FONT_HERSHEY_PLAIN, 1.2 , cv::Scalar (0xFF , 0xFF , 0xFF ), 2 );
673+ cv::imwrite (" _" + file_names[f - BATCH_SIZE + 1 + b], img);
662674 }
663- cv::imwrite ( " _ " + f, img) ;
675+ fcount = 0 ;
664676 }
665677
666678 // Destroy the engine
0 commit comments