Description
When
so
then
but in your code (https://github.com/zhangxiaosong18/FreeAnchor/blob/master/maskrcnn_benchmark/modeling/rpn/free_anchor_loss.py#L161), you only use the predicted probability of the target class (`matched_cls_prob` in your code), which means the predicted probabilities of all the other classes are ignored. I think this differs from `retinanet_cls_loss` as defined in https://github.com/zhangxiaosong18/FreeAnchor/blob/master/maskrcnn_benchmark/modeling/rpn/retinanet_loss.py#L142.
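To make the difference concrete, here is a short derivation in notation assumed for this comment (not taken from the repository): $p_{jc}$ is the sigmoid output of anchor $j$ for class $c$, and $y_i$ is the target class of object $i$. Summing per-class BCE over all classes against a one-hot target (focal weighting omitted) gives

$$
\mathcal{L}^{\text{BCE}}_{ij} = -\log p_{j,y_i} - \sum_{c \neq y_i} \log\left(1 - p_{jc}\right),
\qquad
\exp\left(-\mathcal{L}^{\text{BCE}}_{ij}\right) = p_{j,y_i} \prod_{c \neq y_i} \left(1 - p_{jc}\right),
$$

whereas the current code uses only the first factor, $p_{j,y_i}$.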
I tried rewriting the code that computes `matched_cls_prob` as below:
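```python
from torch import nn
# One-hot target per object, [num_objects, num_classes], on the same device/dtype as cls_prob_
labels_mul = cls_prob_.new_zeros(len(labels_), self.num_classes)
labels_mul.scatter_(1, labels_.view(-1, 1), 1.0)
# Repeat for each of the top-k matched anchors: [num_objects, pre_anchor_topk, num_classes]
labels_mul = labels_mul.unsqueeze(1).repeat(1, self.pre_anchor_topk, 1)
# Per-class BCE summed over classes equals -log(p_target * prod_{c != target} (1 - p_c))
loss_mul_class = nn.BCELoss(reduction="none")(cls_prob_[matched], labels_mul).sum(dim=-1)
matched_cls_prob = (-loss_mul_class).exp()
```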
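A small standalone check of what changes, written in plain Python so it runs without PyTorch. The helper names and the probability values are hypothetical, chosen only for illustration: `target_only_prob` mimics taking just the target-class probability, and `all_class_prob` mimics `exp(-sum of per-class BCE)` against a one-hot target.

```python
import math


def bce(p: float, y: float) -> float:
    """Binary cross-entropy of one sigmoid output p against a label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def target_only_prob(probs: list[float], target: int) -> float:
    """Probability of the target class only; other classes are ignored."""
    return probs[target]


def all_class_prob(probs: list[float], target: int) -> float:
    """exp(-sum of per-class BCE) = p_target * prod over c != target of (1 - p_c)."""
    total = sum(bce(p, 1.0 if c == target else 0.0) for c, p in enumerate(probs))
    return math.exp(-total)


# Hypothetical sigmoid outputs of two anchors over 3 classes; class 0 is the target.
calm = [0.8, 0.5, 0.1]        # moderately unsure about class 1
confused = [0.8, 0.99, 0.1]   # very confident about the wrong class 1

print(target_only_prob(calm, 0), target_only_prob(confused, 0))  # 0.8 0.8
print(all_class_prob(calm, 0))      # about 0.36 (0.8 * 0.5 * 0.9)
print(all_class_prob(confused, 0))  # about 0.0072 (0.8 * 0.01 * 0.9)
```

With the target-only probability both anchors look equally good, while the all-class version heavily penalizes the anchor that is confident about a wrong class.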
Did I get something wrong? @zhangxiaosong18