Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit b2f8749

Browse files
committed
update focal loss
1 parent ef48f43 commit b2f8749

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

DiceLoss/dice_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class BinaryDiceLoss(nn.Module):
4242

4343
def __init__(self, ignore_index=None, reduction='mean',**kwargs):
4444
super(BinaryDiceLoss, self).__init__()
45-
self.smooth = 1 # suggest set a large number when TP is large
45+
self.smooth = 1 # suggest set a large number when target area is large,like '10|100'
4646
self.ignore_index = ignore_index
4747
self.reduction = reduction
4848
self.batch_dice = False # treat a large map when True
@@ -67,7 +67,7 @@ def forward(self, output, target, use_sigmoid=True):
6767
target = target.contiguous().view(dim0, -1).float()
6868

6969
num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
70-
den = torch.sum(output.pow(2) + target.pow(2), dim=1) + self.smooth
70+
den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth
7171

7272
loss = 1 - (num / den)
7373

FocalLoss/focal_loss.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,51 +18,51 @@ class BinaryFocalLoss(nn.Module):
1818
balance_index: (int) balance class index, should be specific when alpha is float
1919
"""
2020

21-
def __init__(self, alpha=[1.0, 1.0], gamma=2, ignore_index=None, reduction='mean'):
21+
def __init__(self, alpha=3, gamma=2, ignore_index=None, reduction='mean',**kwargs):
2222
super(BinaryFocalLoss, self).__init__()
23-
if alpha is None:
24-
alpha = [0.25, 0.75]
2523
self.alpha = alpha
2624
self.gamma = gamma
27-
self.smooth = 1e-6
25+
self.smooth = 1e-6 # set '1e-4' when train with FP16
2826
self.ignore_index = ignore_index
2927
self.reduction = reduction
3028

3129
assert self.reduction in ['none', 'mean', 'sum']
3230

33-
if self.alpha is None:
34-
self.alpha = torch.ones(2)
35-
elif isinstance(self.alpha, (list, np.ndarray)):
36-
self.alpha = np.asarray(self.alpha)
37-
self.alpha = np.reshape(self.alpha, (2))
38-
assert self.alpha.shape[0] == 2, \
39-
'the `alpha` shape is not match the number of class'
40-
elif isinstance(self.alpha, (float, int)):
41-
self.alpha = np.asarray([self.alpha, 1.0 - self.alpha], dtype=np.float).view(2)
31+
# if self.alpha is None:
32+
# self.alpha = torch.ones(2)
33+
# elif isinstance(self.alpha, (list, np.ndarray)):
34+
# self.alpha = np.asarray(self.alpha)
35+
# self.alpha = np.reshape(self.alpha, (2))
36+
# assert self.alpha.shape[0] == 2, \
37+
# 'the `alpha` shape is not match the number of class'
38+
# elif isinstance(self.alpha, (float, int)):
39+
# self.alpha = np.asarray([self.alpha, 1.0 - self.alpha], dtype=np.float).view(2)
4240

43-
else:
44-
raise TypeError('{} not supported'.format(type(self.alpha)))
41+
# else:
42+
# raise TypeError('{} not supported'.format(type(self.alpha)))
4543

4644
def forward(self, output, target):
4745
prob = torch.sigmoid(output)
4846
prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth)
4947

48+
valid_mask = None
49+
if self.ignore_index is not None:
50+
valid_mask = (target != self.ignore_index).float()
51+
5052
pos_mask = (target == 1).float()
5153
neg_mask = (target == 0).float()
54+
if valid_mask is not None:
55+
pos_mask = pos_mask * valid_mask
56+
neg_mask = neg_mask * valid_mask
57+
58+
pos_weight = (pos_mask * torch.pow(1 - prob, self.gamma)).detach()
59+
pos_loss = -torch.sum(pos_weight * torch.log(prob)) / (torch.sum(pos_weight) + 1e-4)
60+
61+
62+
neg_weight = (neg_mask * torch.pow(prob, self.gamma)).detach()
63+
neg_loss = -self.alpha * torch.sum(neg_weight * F.logsigmoid(-logit)) / (torch.sum(neg_weight) + 1e-4)
64+
loss = pos_loss + neg_loss
5265

53-
pos_loss = -self.alpha[0] * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob) * pos_mask
54-
neg_loss = -self.alpha[1] * torch.pow(prob, self.gamma) * \
55-
torch.log(torch.sub(1.0, prob)) * neg_mask
56-
57-
neg_loss = neg_loss.sum()
58-
pos_loss = pos_loss.sum()
59-
num_pos = pos_mask.view(pos_mask.size(0), -1).sum()
60-
num_neg = neg_mask.view(neg_mask.size(0), -1).sum()
61-
62-
if num_pos == 0:
63-
loss = neg_loss
64-
else:
65-
loss = pos_loss / num_pos + neg_loss / num_neg
6666
return loss
6767

6868

0 commit comments

Comments
 (0)