11import os
22import abc
3+ import sys
4+ import tqdm
5+ import pickle
36import numpy as np
47import scipy .misc
58from lib import util
9+ from lib .external .dataset import factory
10+ from . import tester_util
611
712
813class TesterABC (abc .ABC ):
@@ -12,7 +17,7 @@ def __init__(self, global_args, tester_args):
1217 self .result_root = tester_args ['result_root' ]
1318
1419 @ abc .abstractmethod
15- def test (self , framework , data_loader , result_dir ):
20+ def run (self , framework , data_loader , result_dir ):
1621 pass
1722
1823
@@ -22,14 +27,12 @@ def __init__(self, global_args, tester_args):
2227 self .n_samples = tester_args ['n_samples' ]
2328 self .max_boxes = tester_args ['max_boxes' ]
2429 self .conf_thresh = tester_args ['conf_thresh' ]
25- self .coord_size = (global_args ['coord_h' ], global_args ['coord_w' ])
2630
27- def test (self , framework , data_loader , result_dir ):
31+ def run (self , framework , data_loader , result_dir ):
2832 assert data_loader .batch_size == 1
2933 pre_proc = data_loader .dataset .pre_proc
3034 class_map = data_loader .dataset .get_number2name_map ()
31- save_dir = os .path .join (self .result_root , result_dir )
32- util .make_dir (save_dir )
35+ util .make_dir (result_dir )
3336
3437 for i , data_dict in enumerate (data_loader ):
3538 if i >= self .n_samples :
@@ -44,28 +47,109 @@ def test(self, framework, data_loader, result_dir):
4447 img_s = data_dict ['img' ][0 ]
4548 gt_boxes_s = data_dict ['boxes' ][0 ]
4649 gt_labels_s = data_dict ['labels' ][0 ]
47- img_size = data_dict ['img_size' ][0 ]
48- input_size = img_s .shape [:2 ]
4950
5051 sort_idx = 0
51- gt_img_path = os .path .join (save_dir , '%03d_%d_%s.png' % (i , sort_idx , 'gt' ))
52- gt_img_s = testutil .draw_boxes (img_s , gt_boxes_s , None , gt_labels_s ,
53- class_map , self .conf_thresh , self .max_boxes )
52+ gt_img_path = os .path .join (result_dir , '%03d_%d_%s.png' % (i , sort_idx , 'gt' ))
53+ gt_img_s = tester_util .draw_boxes (
54+ img_s , gt_boxes_s , None , gt_labels_s ,
55+ class_map , self .conf_thresh , self .max_boxes )
5456 scipy .misc .imsave (gt_img_path , gt_img_s )
55- del gt_img_s
57+ sort_idx += 1
5658
59+ # draw_boxes
60+ pred_img_path = os .path .join (result_dir , '%03d_%d_%s.png' % (i , sort_idx , 'pred' ))
61+ pred_img_s = tester_util .draw_boxes (
62+ img_s , pred_boxes_s , pred_confs_s , pred_labels_s ,
63+ class_map , self .conf_thresh , self .max_boxes )
64+ scipy .misc .imsave (pred_img_path , pred_img_s )
65+ sort_idx += 1
66+
67+
68+ class QuantTester (TesterABC ):
69+ def __init__ (self , global_args , tester_args ):
70+ super (QuantTester , self ).__init__ (global_args , tester_args )
71+ self .n_classes = global_args ['n_classes' ]
72+ self .imdb_name = tester_args ['dataset' ]
73+ self .iou_thresh = tester_args ['iou_thresh' ]
74+ assert self .imdb_name in ('voc_2007_test' , 'coco_2017_val' , 'coco_2017_test-dev' )
75+
76+ def test (self , framework , data_loader , result_dir_name ):
77+ assert data_loader .batch_size == 1
78+ num_samples = data_loader .dataset .__len__ ()
79+ all_boxes = [[[] for _ in range (num_samples )] for _ in range (self .n_classes )]
80+
81+ # import cv2
82+ s = 0
83+ times = list ()
84+ data_loader_pbar = tqdm (data_loader )
85+ for idx , data_dict in enumerate (data_loader_pbar ):
86+ output_dict , result_dict = framework .infer_forward (data_dict , flip = self .flip_test )
87+ times .append (result_dict ['time' ])
88+ mean_time = np .mean (times [10 :]) if len (times ) > 10 else np .mean (times )
89+ data_loader_pbar .set_description ('infer time: %.4f sec' % mean_time )
90+
91+ # total predict boxes shape : (batch, # pred box, 4)
92+ # total predict boxes confidence shape : (batch, # pred box, 1)
93+ # total predict boxes label shape : (batch, # pred box, 1)
94+ img_size_s = data_dict ['img_size' ].float ()[0 ]
95+ input_size = data_dict ['img' ].shape [2 :]
96+
97+ boxes_s = result_dict ['boxes_l' ][0 ]
98+ confs_s = result_dict ['confs_l' ][0 ]
99+ labels_s = result_dict ['labels_l' ][0 ]
100+
101+ boxes_s [:, [0 , 2 ]] *= (img_size_s [1 ] / input_size [1 ])
102+ boxes_s [:, [1 , 3 ]] *= (img_size_s [0 ] / input_size [0 ])
103+ boxes_s , confs_s , labels_s = \
104+ util .sort_boxes_s (boxes_s , confs_s , labels_s )
105+
106+ boxes_s = util .cvt_torch2numpy (boxes_s )
107+ confs_s = util .cvt_torch2numpy (confs_s )
108+ labels_s = util .cvt_torch2numpy (labels_s )
109+
110+ if len (confs_s .shape ) == 1 :
111+ confs_s = np .expand_dims (confs_s , axis = 1 )
112+ for i , (cls_box , cls_conf , cls_label ) in \
113+ enumerate (zip (boxes_s , confs_s , labels_s )):
114+
115+ cls_box_with_conf = np .concatenate ((cls_box , cls_conf ), axis = 0 )
116+ cls_box_with_conf = np .expand_dims (cls_box_with_conf , axis = 0 )
117+ all_boxes [int (cls_label )][s ].append (cls_box_with_conf )
118+
119+ for c in range (self .n_classes ):
120+ all_boxes [c ][s ] = np .concatenate (all_boxes [c ][s ], axis = 0 ) \
121+ if 0 < len (all_boxes [c ][s ]) else np .concatenate ([[]], axis = 0 )
57122 data_dict .clear ()
58- output_dict .clear ()
59- result_dict .clear ()
60- return save_dir
61-
62-
63- def save_img (save_dir , img_dict , key , size , range_ratio , idx , sort_idx , tag ):
64- if key in img_dict .keys ():
65- img = util .cvt_torch2numpy (img_dict [key ])[0 ] * range_ratio
66- if len (img .shape ) == 3 :
67- img = np .squeeze (np .transpose (img , (1 , 2 , 0 )))
68- img = scipy .misc .imresize (img , size , interp = 'bilinear' )
69- img_path = os .path .join (save_dir , '%03d_%d_%s.png' % (idx , sort_idx , tag ))
70- scipy .misc .imsave (img_path , img )
71- return sort_idx + 1
123+
124+ # create result directories
125+ if not os .path .exists (self .result_dir ):
126+ os .mkdir (self .result_dir )
127+ result_dir_path = os .path .join (self .result_dir , result_dir_name )
128+ if not os .path .exists (result_dir_path ):
129+ os .mkdir (result_dir_path )
130+
131+ dataset_root = data_loader .dataset .get_dataset_root ()
132+ if isinstance (dataset_root , list ):
133+ dataset_root = dataset_root [0 ]
134+ imdb = factory .get_imdb (self .imdb_name , dataset_root )
135+
136+ if 'coco' in self .imdb_name :
137+ sys_stdout = sys .stdout
138+ result_file_path = open (os .path .join (result_dir_path , 'ap_ar.txt' ), 'w' )
139+ sys .stdout = result_file_path
140+ imdb .evaluate_detections (all_boxes , result_dir_path )
141+ sys .stdout = sys_stdout
142+ result_file_path .close ()
143+
144+ else :
145+ det_file_path = os .path .join (result_dir_path , 'detection_results.pkl' )
146+ with open (det_file_path , 'wb' ) as det_file :
147+ pickle .dump (all_boxes , det_file , pickle .HIGHEST_PROTOCOL )
148+ result_msg = imdb .evaluate_detections (all_boxes , result_dir_path )
149+ result_file_path = os .path .join (result_dir_path , 'mean_ap.txt' )
150+ with open (result_file_path , 'w' ) as file :
151+ file .write (result_msg )
152+ os .remove (det_file_path )
153+
154+ all_boxes .clear ()
155+
0 commit comments