@@ -15,66 +15,57 @@ def __init__(self, global_args, post_proc_args):
1515 self .max_boxes = post_proc_args ['max_boxes' ]
1616
1717 def __filter_cls_boxes_s__ (self , boxes_s , confs_s , pi_s ):
18- cls_boxes_sl = list ()
19- cls_confs_sl = list ()
20- cls_labels_sl = list ()
18+ boxes_sl = list ()
19+ confs_sl = list ()
20+ labels_sl = list ()
2121
2222 norm_pi_s = pi_s / torch .max (pi_s )
23+ keep_idxes = torch .nonzero (norm_pi_s > self .pi_thresh ).view (- 1 )
24+ boxes_s = boxes_s [:, keep_idxes ]
25+ confs_s = confs_s [:, keep_idxes ]
26+
2327 for c in range (self .n_classes - 1 ):
24- cls_boxes_sc = boxes_s [c ]
25- cls_confs_sc = confs_s [c ]
26- # cls_pi_sc = norm_pi_s.clone()
28+ boxes_sc = boxes_s [c ]
29+ confs_sc = confs_s [c ]
2730
28- if len (cls_boxes_sc ) == 0 :
31+ if len (boxes_sc ) == 0 :
2932 continue
3033
31- # print(cls_boxes_sc.shape)
32- keep_idxes = torch .nonzero (norm_pi_s > self .pi_thresh ).view (- 1 )
33- cls_boxes_sc = cls_boxes_sc [keep_idxes ]
34- cls_confs_sc = cls_confs_sc [keep_idxes ]
35-
36- # print(cls_boxes_sc.shape)
37- keep_idxes = torch .nonzero (cls_confs_sc > self .conf_thresh ).view (- 1 )
38- cls_boxes_sc = cls_boxes_sc [keep_idxes ]
39- cls_confs_sc = cls_confs_sc [keep_idxes ]
34+ keep_idxes = torch .nonzero (confs_sc > self .conf_thresh ).view (- 1 )
35+ boxes_sc = boxes_sc [keep_idxes ]
36+ confs_sc = confs_sc [keep_idxes ]
4037 if keep_idxes .shape [0 ] == 0 :
4138 continue
42- # print(cls_boxes_sc.shape)
43- # print('')
4439
4540 if self .nms_thresh <= 0.0 :
46- cls_boxes_sc , cls_confs_sc = lib_util .sort_boxes_s (cls_boxes_sc , cls_confs_sc )
47- cls_boxes_sc , cls_confs_sc = cls_boxes_sc [:1 ], cls_confs_sc [:1 ]
41+ boxes_sc , confs_sc = lib_util .sort_boxes_s (boxes_sc , confs_sc )
42+ boxes_sc , confs_sc = boxes_sc [:1 ], confs_sc [:1 ]
4843 elif self .nms_thresh > 1.0 :
4944 pass
5045 else :
51- keep_idxes = nms (cls_boxes_sc , cls_confs_sc , self .nms_thresh )
46+ keep_idxes = nms (boxes_sc , confs_sc , self .nms_thresh )
5247 keep_idxes = keep_idxes .long ().view (- 1 )
53- cls_boxes_sc = cls_boxes_sc [keep_idxes ]
54- cls_confs_sc = cls_confs_sc [keep_idxes ].unsqueeze (dim = 1 )
48+ boxes_sc = boxes_sc [keep_idxes ]
49+ confs_sc = confs_sc [keep_idxes ].unsqueeze (dim = 1 )
5550
56- labels_css = torch .zeros (cls_confs_sc .shape ).float ().cuda ()
51+ labels_css = torch .zeros (confs_sc .shape ).float ().cuda ()
5752 labels_css += c
5853
59- cls_boxes_sl .append (cls_boxes_sc )
60- cls_confs_sl .append (cls_confs_sc )
61- cls_labels_sl .append (labels_css )
62- # exit()
54+ boxes_sl .append (boxes_sc )
55+ confs_sl .append (confs_sc )
56+ labels_sl .append (labels_css )
6357
64- if len (cls_boxes_sl ) > 0 :
65- boxes_s = torch .cat (cls_boxes_sl , dim = 0 )
66- confs_s = torch .cat (cls_confs_sl , dim = 0 )
67- labels_s = torch .cat (cls_labels_sl , dim = 0 )
58+ if len (boxes_sl ) > 0 :
59+ boxes_s = torch .cat (boxes_sl , dim = 0 )
60+ confs_s = torch .cat (confs_sl , dim = 0 )
61+ labels_s = torch .cat (labels_sl , dim = 0 )
6862 else :
6963 boxes_s = torch .zeros ((1 , 4 )).float ().cuda ()
7064 confs_s = torch .zeros ((1 , 1 )).float ().cuda ()
7165 labels_s = torch .zeros ((1 , 1 )).float ().cuda ()
72-
73- boxes_s , confs_s , labels_s = lib_util .sort_boxes_s (boxes_s , confs_s , labels_s )
7466 return boxes_s , confs_s , labels_s
7567
7668 def forward (self , mu , prob , pi ):
77- # print('mu', torch.min(mu), torch.max(mu))
7869 boxes = mu .transpose (1 , 2 ).clone ()
7970 boxes [:, :, [0 , 2 ]] = boxes [:, :, [0 , 2 ]] * (self .input_size [1 ] / self .coord_range [1 ])
8071 boxes [:, :, [1 , 3 ]] = boxes [:, :, [1 , 3 ]] * (self .input_size [0 ] / self .coord_range [0 ])
@@ -85,7 +76,7 @@ def forward(self, mu, prob, pi):
8576 boxes_l , confs_l , labels_l = list (), list (), list ()
8677 for i , (boxes_s , confs_s ) in enumerate (zip (boxes , confs )):
8778 boxes_s , confs_s , labels_s = self .__filter_cls_boxes_s__ (boxes_s , confs_s , pi [i , 0 ])
88- boxes_l .append (boxes_s [: self . max_boxes ] )
89- confs_l .append (confs_s [: self . max_boxes ] )
90- labels_l .append (labels_s [: self . max_boxes ] + 1 )
79+ boxes_l .append (boxes_s )
80+ confs_l .append (confs_s )
81+ labels_l .append (labels_s + 1 )
9182 return boxes_l , confs_l , labels_l
0 commit comments