@@ -23,12 +23,12 @@ def __filter_cls_boxes_s__(self, boxes_s, confs_s, pi_s):
2323 for c in range (self .n_classes - 1 ):
2424 cls_boxes_sc = boxes_s [c ]
2525 cls_confs_sc = confs_s [c ]
26- cls_pi_sc = norm_pi_s .clone ()
26+ # cls_pi_sc = norm_pi_s.clone()
2727
2828 if len (cls_boxes_sc ) == 0 :
2929 continue
3030
31- keep_idxes = torch .nonzero (cls_pi_sc > self .pi_thresh ).view (- 1 )
31+ keep_idxes = torch .nonzero (norm_pi_s > self .pi_thresh ).view (- 1 )
3232 cls_boxes_sc = cls_boxes_sc [keep_idxes ]
3333 cls_confs_sc = cls_confs_sc [keep_idxes ]
3434
@@ -50,7 +50,7 @@ def __filter_cls_boxes_s__(self, boxes_s, confs_s, pi_s):
5050 cls_confs_sc = cls_confs_sc [keep_idxes ].unsqueeze (dim = 1 )
5151
5252 labels_css = torch .zeros (cls_confs_sc .shape ).float ().cuda ()
53- labels_css += ( c + 1 )
53+ labels_css += c
5454
5555 cls_boxes_sl .append (cls_boxes_sc )
5656 cls_confs_sl .append (cls_confs_sc )
@@ -69,6 +69,7 @@ def __filter_cls_boxes_s__(self, boxes_s, confs_s, pi_s):
6969 return boxes_s , confs_s , labels_s
7070
7171 def forward (self , mu , prob , pi ):
72+ # print('mu', torch.min(mu), torch.max(mu))
7273 boxes = mu .transpose (1 , 2 ).clone ()
7374 boxes [:, :, [0 , 2 ]] = boxes [:, :, [0 , 2 ]] * (self .input_size [1 ] / self .coord_range [1 ])
7475 boxes [:, :, [1 , 3 ]] = boxes [:, :, [1 , 3 ]] * (self .input_size [0 ] / self .coord_range [0 ])
@@ -81,5 +82,5 @@ def forward(self, mu, prob, pi):
8182 boxes_s , confs_s , labels_s = self .__filter_cls_boxes_s__ (boxes_s , confs_s , pi [i , 0 ])
8283 boxes_l .append (boxes_s [:self .max_boxes ])
8384 confs_l .append (confs_s [:self .max_boxes ])
84- labels_l .append (labels_s [:self .max_boxes ])
85+ labels_l .append (labels_s [:self .max_boxes ] + 1 )
8586 return boxes_l , confs_l , labels_l
0 commit comments