@@ -155,7 +155,7 @@ void doInference(IExecutionContext& context, cudaStream_t& stream, void **buffer
155155 cudaStreamSynchronize (stream);
156156}
157157
158- bool parse_args (int argc, char ** argv, std::string& wts, std::string& engine, float & gd, float & gw, std::string& img_dir) {
158+ bool parse_args (int argc, char ** argv, std::string& wts, std::string& engine, float & gd, float & gw, std::string& img_dir, std::string& labels_filename ) {
159159 if (argc < 4 ) return false ;
160160 if (std::string (argv[1 ]) == " -s" && (argc == 5 || argc == 7 )) {
161161 wts = std::string (argv[2 ]);
@@ -182,9 +182,10 @@ bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, fl
182182 } else {
183183 return false ;
184184 }
185- } else if (std::string (argv[1 ]) == " -d" && argc == 4 ) {
185+ } else if (std::string (argv[1 ]) == " -d" && argc == 5 ) {
186186 engine = std::string (argv[2 ]);
187187 img_dir = std::string (argv[3 ]);
188+ labels_filename = std::string (argv[4 ]);
188189 } else {
189190 return false ;
190191 }
@@ -216,12 +217,6 @@ std::vector<cv::Mat> process_mask(const float* proto, std::vector<Yolo::Detectio
216217 }
217218 e = 1 .0f / (1 .0f + expf (-e));
218219 mask_mat.at <float >(y, x) = e;
219- // if (e > 0.5) {
220- // // TODO(Call for PR): Use different colors for different class ids
221- // mask_mat.at<cv::Vec3b>(y, x)[2] = 0xFF;
222- // mask_mat.at<cv::Vec3b>(y, x)[1] = 0x38;
223- // mask_mat.at<cv::Vec3b>(y, x)[0] = 0x38;
224- // }
225220 }
226221 }
227222 cv::resize (mask_mat, mask_mat, cv::Size (INPUT_W, INPUT_H));
@@ -251,7 +246,7 @@ cv::Mat scale_mask(cv::Mat mask, cv::Mat img) {
251246 return res;
252247}
253248
254- void draw_mask_bbox (cv::Mat& img, std::vector<Yolo::Detection>& dets, std::vector<cv::Mat>& masks) {
249+ void draw_mask_bbox (cv::Mat& img, std::vector<Yolo::Detection>& dets, std::vector<cv::Mat>& masks, std::unordered_map< int , std::string>& labels_map ) {
255250 static std::vector<uint32_t > colors = {0xFF3838 , 0xFF9D97 , 0xFF701F , 0xFFB21D , 0xCFD231 , 0x48F90A ,
256251 0x92CC17 , 0x3DDB86 , 0x1A9334 , 0x00D4BB , 0x2C99A8 , 0x00C2FF ,
257252 0x344593 , 0x6473FF , 0x0018EC , 0x8438FF , 0x520085 , 0xCB38FF ,
@@ -273,8 +268,23 @@ void draw_mask_bbox(cv::Mat& img, std::vector<Yolo::Detection>& dets, std::vecto
273268 }
274269
275270 cv::rectangle (img, r, bgr, 2 );
276- // TODO(Call for PR): convert class id to class name
277- cv::putText (img, std::to_string ((int )dets[i].class_id ), cv::Point (r.x , r.y - 1 ), cv::FONT_HERSHEY_PLAIN, 1.2 , cv::Scalar::all (0xFF ), 2 );
271+
272+ // Get the size of the text
273+ cv::Size textSize = cv::getTextSize (labels_map[(int )dets[i].class_id ] + " " + to_string_with_precision (dets[i].conf ), cv::FONT_HERSHEY_PLAIN, 1.2 , 2 , NULL );
274+ // Set the top left corner of the rectangle
275+ cv::Point topLeft (r.x , r.y - textSize.height );
276+
277+ // Set the bottom right corner of the rectangle
278+ cv::Point bottomRight (r.x + textSize.width , r.y + textSize.height );
279+
280+ // Set the thickness of the rectangle lines
281+ int lineThickness = 2 ;
282+
283+ // Draw the rectangle on the image
284+ cv::rectangle (img, topLeft, bottomRight, bgr, -1 );
285+
286+ cv::putText (img, labels_map[(int )dets[i].class_id ] + " " + to_string_with_precision (dets[i].conf ), cv::Point (r.x , r.y + 4 ), cv::FONT_HERSHEY_PLAIN, 1.2 , cv::Scalar::all (0xFF ), 2 );
287+
278288 }
279289}
280290
@@ -283,12 +293,14 @@ int main(int argc, char** argv) {
283293
284294 std::string wts_name = " " ;
285295 std::string engine_name = " " ;
296+ std::string labels_filename = " " ;
297+
286298 float gd = 0 .0f , gw = 0 .0f ;
287299 std::string img_dir;
288- if (!parse_args (argc, argv, wts_name, engine_name, gd, gw, img_dir)) {
300+ if (!parse_args (argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename )) {
289301 std::cerr << " arguments not right!" << std::endl;
290302 std::cerr << " ./yolov5_seg -s [.wts] [.engine] [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl;
291- std::cerr << " ./yolov5_seg -d [.engine] ../samples // deserialize plan file and run inference" << std::endl;
303+ std::cerr << " ./yolov5_seg -d [.engine] ../samples coco.txt // deserialize plan file, read the labels file and run inference" << std::endl;
292304 return -1 ;
293305 }
294306
@@ -328,6 +340,18 @@ int main(int argc, char** argv) {
328340 std::cerr << " read_files_in_dir failed." << std::endl;
329341 return -1 ;
330342 }
343+
344+ // read the txt file for classnames
345+ std::ifstream labels_file (labels_filename, std::ios::binary);
346+ if (!labels_file.good ()) {
347+ std::cerr << " read " << labels_filename << " error!" << std::endl;
348+ return -1 ;
349+ }
350+ std::unordered_map<int , std::string> labels_map;
351+ read_labels (labels_filename, labels_map);
352+
353+ assert (CLASS_NUM == labels_map.size ());
354+
331355
332356 static float prob[BATCH_SIZE * OUTPUT_SIZE1];
333357 static float proto[BATCH_SIZE * OUTPUT_SIZE2];
@@ -398,7 +422,7 @@ int main(int argc, char** argv) {
398422 cv::Mat img = imgs_buffer[b];
399423
400424 auto masks = process_mask (&proto[b * OUTPUT_SIZE2], res);
401- draw_mask_bbox (img, res, masks);
425+ draw_mask_bbox (img, res, masks, labels_map );
402426 cv::imwrite (" _" + file_names[f - fcount + 1 + b], img);
403427 }
404428 fcount = 0 ;
0 commit comments