Skip to content

Commit d5012b0

Browse files
committed
Fix the softmax warning.
1 parent ba5ab6d commit d5012b0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

lib/nets/network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _region_proposal(self, net_conv):
240240

241241
# change it so that the score has 2 as its channel size
242242
rpn_cls_score_reshape = rpn_cls_score.view(1, 2, -1, rpn_cls_score.size()[-1]) # batch * 2 * (num_anchors*h) * w
243-
rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape)
243+
rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, dim=1)
244244

245245
# Move channel to the last dimenstion, to fit the input of python functions
246246
rpn_cls_prob = rpn_cls_prob_reshape.view_as(rpn_cls_score).permute(0, 2, 3, 1) # batch * h * w * (num_anchors * 2)
@@ -275,7 +275,7 @@ def _region_proposal(self, net_conv):
275275
def _region_classification(self, fc7):
276276
cls_score = self.cls_score_net(fc7)
277277
cls_pred = torch.max(cls_score, 1)[1]
278-
cls_prob = F.softmax(cls_score)
278+
cls_prob = F.softmax(cls_score, dim=1)
279279
bbox_pred = self.bbox_pred_net(fc7)
280280

281281
self._predictions["cls_score"] = cls_score

0 commit comments

Comments
 (0)