
Commit 2d72433

add Corner losses (open-mmlab#2901)
* add ae loss and gaussian focal loss
* update losses/__init__.py
* formatting
* Update __init__.py
* update losses
* add weighted_loss in gaussian_focal_loss
* rename AELoss, move comment to docstring
* fix formatting
1 parent faf3038 commit 2d72433

File tree

mmdet/models/losses/__init__.py
mmdet/models/losses/ae_loss.py
mmdet/models/losses/gaussian_focal_loss.py

3 files changed: +163 -1 lines changed

mmdet/models/losses/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,8 +1,10 @@
 from .accuracy import Accuracy, accuracy
+from .ae_loss import AssociativeEmbeddingLoss
 from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
 from .focal_loss import FocalLoss, sigmoid_focal_loss
+from .gaussian_focal_loss import GaussianFocalLoss
 from .ghm_loss import GHMC, GHMR
 from .iou_loss import (BoundedIoULoss, GIoULoss, IoULoss, bounded_iou_loss,
                        iou_loss)
@@ -18,5 +20,5 @@
     'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
     'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
     'weight_reduce_loss', 'weighted_loss', 'L1Loss', 'l1_loss', 'isr_p',
-    'carl_loss'
+    'carl_loss', 'AssociativeEmbeddingLoss', 'GaussianFocalLoss'
 ]
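
Since both new classes are exported here and registered in the LOSSES registry (see the decorators in the files below), they can be instantiated directly or built from a config dict. A minimal sketch of both routes; the parameter values shown are illustrative, not mandated by this commit:

from mmdet.models.builder import build_loss
from mmdet.models.losses import AssociativeEmbeddingLoss, GaussianFocalLoss

# direct instantiation
ae_loss = AssociativeEmbeddingLoss(pull_weight=0.25, push_weight=0.25)

# built from a config dict via the LOSSES registry, as detector configs do
heatmap_loss = build_loss(
    dict(type='GaussianFocalLoss', alpha=2.0, gamma=4.0, loss_weight=1.0))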

mmdet/models/losses/ae_loss.py

Lines changed: 97 additions & 0 deletions

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


def ae_loss_per_image(tl_preds, br_preds, match):
    """Associative Embedding Loss in one image.

    Associative Embedding Loss includes two parts: pull loss and push loss.
    Pull loss makes embedding vectors from the same object closer to each
    other. Push loss distinguishes embedding vectors from different objects
    and makes the gap between them large enough.

    During computation, there are usually 3 cases:
        - no object in image: both pull loss and push loss will be 0.
        - one object in image: push loss will be 0 and pull loss is computed
          by the two corners of the only object.
        - more than one object in image: pull loss is computed by corner
          pairs from each object, push loss is computed by each object with
          all other objects. We use a confusion matrix with 0 on the diagonal
          to compute the push loss.

    Args:
        tl_preds (tensor): Embedding feature map of the top-left corner.
        br_preds (tensor): Embedding feature map of the bottom-right corner.
        match (list): Downsampled coordinate pair of each ground truth box.
    """

    tl_list, br_list, me_list = [], [], []
    if len(match) == 0:  # no object in image
        pull_loss = tl_preds.sum()[None] * 0.
        push_loss = tl_preds.sum()[None] * 0.
    else:
        for m in match:
            [tl_y, tl_x], [br_y, br_x] = m
            tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
            br_e = br_preds[:, br_y, br_x].view(-1, 1)
            tl_list.append(tl_e)
            br_list.append(br_e)
            me_list.append((tl_e + br_e) / 2.0)

        tl_list = torch.cat(tl_list)
        br_list = torch.cat(br_list)
        me_list = torch.cat(me_list)

        assert tl_list.size() == br_list.size()

        # N is the object number in the image, M is the dimension of the
        # embedding vector
        N, M = tl_list.size()

        pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
        pull_loss = pull_loss.sum() / N

        margin = 1  # exp setting of CornerNet, details in section 3.3 of paper

        # confusion matrix of push loss
        conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
        conf_weight = 1 - torch.eye(N).type_as(me_list)
        conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())

        if N > 1:  # more than one object in current image
            push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
        else:  # only one object: push loss is 0, as documented above
            push_loss = tl_preds.sum()[None] * 0.

    return pull_loss, push_loss


@LOSSES.register_module()
class AssociativeEmbeddingLoss(nn.Module):
    """Associative Embedding Loss.

    More details can be found in
    `Associative Embedding <https://arxiv.org/abs/1611.05424>`_ and
    `CornerNet <https://arxiv.org/abs/1808.01244>`_ .
    Code is modified from `kp_utils.py <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L180>`_  # noqa: E501

    Args:
        pull_weight (float): Loss weight for corners from the same object.
        push_weight (float): Loss weight for corners from different objects.
    """

    def __init__(self, pull_weight=0.25, push_weight=0.25):
        super(AssociativeEmbeddingLoss, self).__init__()
        self.pull_weight = pull_weight
        self.push_weight = push_weight

    def forward(self, pred, target, match):
        batch = pred.size(0)
        pull_all, push_all = 0.0, 0.0
        for i in range(batch):
            pull, push = ae_loss_per_image(pred[i], target[i], match[i])

            pull_all += self.pull_weight * pull
            push_all += self.push_weight * push

        return pull_all, push_all
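
A hedged usage sketch of the forward interface above; the tensor shapes and corner coordinates are illustrative (CornerNet uses a 1-dimensional embedding per corner), not values taken from this commit:

import torch

from mmdet.models.losses import AssociativeEmbeddingLoss

loss_embedding = AssociativeEmbeddingLoss(pull_weight=0.25, push_weight=0.25)

# pred / target: top-left and bottom-right corner embedding maps,
# shape (batch, embedding_dim, H, W)
tl_emb = torch.rand(2, 1, 96, 96)
br_emb = torch.rand(2, 1, 96, 96)

# match[i]: ([tl_y, tl_x], [br_y, br_x]) pairs of every gt box in image i
match = [
    [[[10, 12], [40, 55]], [[5, 60], [30, 90]]],  # two objects in image 0
    [],                                           # no objects in image 1
]

pull_loss, push_loss = loss_embedding(tl_emb, br_emb, match)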

mmdet/models/losses/gaussian_focal_loss.py

Lines changed: 63 additions & 0 deletions

import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
    # pixels where the gaussian target is exactly 1 are positives; all other
    # pixels are negatives down-weighted by (1 - gaussian_target) ** gamma
    eps = 1e-12
    pos_weights = gaussian_target.eq(1)
    neg_weights = (1 - gaussian_target).pow(gamma)
    pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
    neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
    return pos_loss + neg_loss


@LOSSES.register_module()
class GaussianFocalLoss(nn.Module):
    """GaussianFocalLoss is a variant of focal loss.

    More details can be found in the `paper
    <https://arxiv.org/abs/1808.01244>`_
    Code is modified from `kp_utils.py
    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_  # noqa: E501
    Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
    not a 0/1 binary target.

    Args:
        alpha (float): Power of prediction.
        gamma (float): Power of target for negative samples.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self,
                 alpha=2.0,
                 gamma=4.0,
                 reduction='mean',
                 loss_weight=1.0):
        super(GaussianFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_reg = self.loss_weight * gaussian_focal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=reduction,
            avg_factor=avg_factor)
        return loss_reg
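
A hedged usage sketch for GaussianFocalLoss; the heatmap shapes and the single hand-placed positive are illustrative only (a real target would carry the full gaussian splat around each corner, and CornerNet-style heads typically pass the number of positive corners as avg_factor):

import torch

from mmdet.models.losses import GaussianFocalLoss

loss_heatmap = GaussianFocalLoss(alpha=2.0, gamma=4.0, loss_weight=1.0)

# pred: corner heatmap after sigmoid, values in (0, 1),
#       shape (batch, num_classes, H, W)
# target: gaussian heatmap, exactly 1.0 at ground-truth corner centers
pred = torch.rand(2, 80, 96, 96).clamp(1e-4, 1 - 1e-4)
target = torch.zeros(2, 80, 96, 96)
target[0, 3, 10, 12] = 1.0  # one positive corner for illustration

# avg_factor would normally be the number of positive corners (here 1)
loss = loss_heatmap(pred, target, avg_factor=1)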
