This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit f78750b

fix binary focal loss

1 parent 4247f6a commit f78750b

File tree

2 files changed, +43 -27 lines changed

FocalLoss/FocalLoss_test.py

Lines changed: 28 additions & 8 deletions

@@ -4,19 +4,39 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from focal_loss import FocalLoss_Ori,BinaryFocalLoss
+from focal_loss import FocalLoss_Ori, BinaryFocalLoss
 
 
 # os.environ['CUDA_VISIBLE_DEVICES'] = '2'
 
 def test_BFL():
-    output = 40*(torch.randint(0,2,(1,1,32,32,32))-0.5)
-    target = torch.zeros_like(output)
-    target[output>0] = 1
-    # target = torch.randint(0,2,(1,1,32,32,32))
-    criterion = BinaryFocalLoss()
-    loss = criterion(output,target)
-    print(loss.item())
+    import matplotlib.pyplot as plt
+    from torch import optim
+    torch.manual_seed(123)
+    shape = (4, 1, 32, 32, 32)
+    datas = 40 * (torch.randint(0, 2, shape) - 0.5)
+    target = torch.zeros_like(datas) + torch.randint(0, 2, size=shape)
+    model = nn.Sequential(*[nn.Conv3d(1, 16, kernel_size=3, padding=1, stride=1),
+                            nn.BatchNorm3d(16),
+                            nn.ReLU(),
+                            nn.Conv3d(16, 1, kernel_size=3, padding=1, stride=1)])
+
+    criterion = BinaryFocalLoss()
+    losses = []
+    optimizer = optim.Adam(params=model.parameters(), lr=0.001)
+    for step in range(100):
+        out = model(datas)
+        loss = criterion(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        losses.append(loss.item())
+        if step % 10 == 0:
+            print(step)
+
+    plt.plot(losses)
+    plt.show()
+
 
 def test_focal():
     num_class = 5
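The rewritten `test_BFL` no longer checks a single forward pass: it optimizes a small Conv3d network against the fixed `BinaryFocalLoss` for 100 steps and plots the loss curve, so a decreasing curve is the pass signal. Below is a minimal, plot-free sketch of the same idea with an explicit assertion; it assumes the same `focal_loss` import path as the test above, and the tiny model and final check are illustrative, not part of the commit.

```python
import torch
import torch.nn as nn
from torch import optim

from focal_loss import BinaryFocalLoss  # same import path as the test above


def check_bfl_decreases(steps=100, lr=1e-3):
    """Train a tiny 3D conv net and assert the binary focal loss goes down."""
    torch.manual_seed(123)
    shape = (4, 1, 32, 32, 32)
    datas = 40 * (torch.randint(0, 2, shape) - 0.5)    # +/-20 "logit-like" inputs
    target = torch.randint(0, 2, size=shape).float()   # random binary labels

    model = nn.Sequential(
        nn.Conv3d(1, 16, kernel_size=3, padding=1),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.Conv3d(16, 1, kernel_size=3, padding=1),
    )
    criterion = BinaryFocalLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    losses = []
    for _ in range(steps):
        loss = criterion(model(datas), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # illustrative check: the averaged loss should drop over training
    assert losses[-1] < losses[0], (losses[0], losses[-1])


if __name__ == "__main__":
    check_bfl_decreases()
    print("binary focal loss decreases during training")
```

Run it directly with `python`, or move the body into a pytest test; either way it needs no display backend.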

FocalLoss/focal_loss.py

Lines changed: 15 additions & 19 deletions

@@ -17,11 +17,11 @@ class BinaryFocalLoss(nn.Module):
         balance_index: (int) balance class index, should be specific when alpha is float
     """
 
-    def __init__(self, alpha=3, gamma=2, ignore_index=None, reduction='mean',**kwargs):
+    def __init__(self, alpha=3, gamma=2, ignore_index=None, reduction='mean', **kwargs):
         super(BinaryFocalLoss, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
-        self.smooth = 1e-6 # set '1e-4' when train with FP16
+        self.smooth = 1e-6  # set '1e-4' when train with FP16
         self.ignore_index = ignore_index
         self.reduction = reduction
 
@@ -55,17 +55,15 @@ def forward(self, output, target):
         neg_mask = neg_mask * valid_mask
 
         pos_weight = (pos_mask * torch.pow(1 - prob, self.gamma)).detach()
-        pos_loss = -torch.sum(pos_weight * torch.log(prob)) / (torch.sum(pos_weight) + 1e-4)
-
-
+        pos_loss = -pos_weight * torch.log(prob)  # / (torch.sum(pos_weight) + 1e-4)
+
         neg_weight = (neg_mask * torch.pow(prob, self.gamma)).detach()
-        neg_loss = -self.alpha * torch.sum(neg_weight * F.logsigmoid(-output)) / (torch.sum(neg_weight) + 1e-4)
+        neg_loss = -self.alpha * neg_weight * F.logsigmoid(-output)  # / (torch.sum(neg_weight) + 1e-4)
         loss = pos_loss + neg_loss
-
+        loss = loss.mean()
         return loss
 
 
-
 class FocalLoss_Ori(nn.Module):
     """
     This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
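This hunk is the actual fix named in the commit message: the positive and negative terms are no longer summed and normalized by their respective focal-weight totals; each element keeps its own loss and the result is reduced with a plain `.mean()` (note the hunk applies the mean unconditionally, independent of the `reduction` argument stored in `__init__`). A standalone sketch of the new per-element computation follows; it assumes `prob` is the sigmoid of the logits clamped by `smooth`, and it omits the `ignore_index`/`valid_mask` handling, since those parts of `forward` lie outside the hunk.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss_sketch(output, target, alpha=3, gamma=2, smooth=1e-6):
    """Per-element binary focal loss as computed after this commit (sketch only)."""
    # assumed: prob comes from a clamped sigmoid computed earlier in forward()
    prob = torch.sigmoid(output).clamp(smooth, 1.0 - smooth)
    pos_mask = (target == 1).float()
    neg_mask = (target == 0).float()

    # focusing weights are detached, exactly as in the diff above
    pos_weight = (pos_mask * torch.pow(1 - prob, gamma)).detach()
    pos_loss = -pos_weight * torch.log(prob)                # no per-term sum-normalization

    neg_weight = (neg_mask * torch.pow(prob, gamma)).detach()
    neg_loss = -alpha * neg_weight * F.logsigmoid(-output)  # no per-term sum-normalization

    return (pos_loss + neg_loss).mean()                     # single mean over all elements


# toy check on random logits and binary labels
logits = torch.randn(2, 1, 8, 8)
labels = torch.randint(0, 2, (2, 1, 8, 8)).float()
print(binary_focal_loss_sketch(logits, labels).item())
```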
@@ -79,7 +77,7 @@ class FocalLoss_Ori(nn.Module):
     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
     """
 
-    def __init__(self, num_class, alpha=[0.25,0.75], gamma=2, balance_index=-1, size_average=True):
+    def __init__(self, num_class, alpha=[0.25, 0.75], gamma=2, balance_index=-1, size_average=True):
         super(FocalLoss_Ori, self).__init__()
         self.num_class = num_class
         self.alpha = alpha
@@ -90,11 +88,11 @@ def __init__(self, num_class, alpha=[0.25,0.75], gamma=2, balance_index=-1, size
         if isinstance(self.alpha, (list, tuple)):
             assert len(self.alpha) == self.num_class
             self.alpha = torch.Tensor(list(self.alpha))
-        elif isinstance(self.alpha, (float,int)):
+        elif isinstance(self.alpha, (float, int)):
             assert 0 < self.alpha < 1.0, 'alpha should be in `(0,1)`)'
             assert balance_index > -1
             alpha = torch.ones((self.num_class))
-            alpha *= 1-self.alpha
+            alpha *= 1 - self.alpha
             alpha[balance_index] = self.alpha
             self.alpha = alpha
         elif isinstance(self.alpha, torch.Tensor):
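The two `__init__` hunks above only add spaces around operators, but the float-`alpha` branch they touch is easy to misread: a scalar `alpha` is expanded into a per-class vector holding `1 - alpha` for every class except `balance_index`, which gets `alpha` itself. A tiny sketch of that expansion with illustrative values:

```python
import torch

num_class, alpha, balance_index = 5, 0.25, 2   # illustrative values, not from the diff

weights = torch.ones(num_class)
weights *= 1 - alpha                 # every class starts at 1 - alpha
weights[balance_index] = alpha       # the balanced class alone gets alpha
print(weights)                       # tensor([0.7500, 0.7500, 0.2500, 0.7500, 0.7500])
```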
@@ -107,9 +105,9 @@ def forward(self, logit, target):
         if logit.dim() > 2:
             # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
             logit = logit.view(logit.size(0), logit.size(1), -1)
-            logit = logit.transpose(1, 2).contiguous() # [N,C,d1*d2..] -> [N,d1*d2..,C]
-            logit = logit.view(-1, logit.size(-1)) # [N,d1*d2..,C]-> [N*d1*d2..,C]
-            target = target.view(-1, 1) # [N,d1,d2,...]->[N*d1*d2*...,1]
+            logit = logit.transpose(1, 2).contiguous()  # [N,C,d1*d2..] -> [N,d1*d2..,C]
+            logit = logit.view(-1, logit.size(-1))  # [N,d1*d2..,C]-> [N*d1*d2..,C]
+            target = target.view(-1, 1)  # [N,d1,d2,...]->[N*d1*d2*...,1]
 
         # -----------legacy way------------
         # idx = target.cpu().long()
@@ -120,19 +118,17 @@ def forward(self, logit, target):
         # pt = (one_hot_key * logit).sum(1) + epsilon
 
         # ----------memory saving way--------
-        pt = logit.gather(1, target).view(-1) + self.eps # avoid apply
+        pt = logit.gather(1, target).view(-1) + self.eps  # avoid apply
         logpt = pt.log()
 
         if self.alpha.device != logpt.device:
             alpha = self.alpha.to(logpt.device)
-        alpha_class = alpha.gather(0,target.view(-1))
-        logpt = alpha_class*logpt
+        alpha_class = alpha.gather(0, target.view(-1))
+        logpt = alpha_class * logpt
         loss = -1 * torch.pow(torch.sub(1.0, pt), self.gamma) * logpt
 
         if self.size_average:
             loss = loss.mean()
         else:
             loss = loss.sum()
         return loss
-
-
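The remaining `FocalLoss_Ori` hunks are whitespace-only, but the flattening they pass through is what makes the later `gather` work, so a small shape-only sketch (with an illustrative 5-class 3D input) may help read the comments above:

```python
import torch

N, C = 2, 5                                  # illustrative batch size and class count
logit = torch.randn(N, C, 4, 4, 4)           # [N, C, d1, d2, d3]
target = torch.randint(0, C, (N, 4, 4, 4))   # [N, d1, d2, d3]

# [N,C,d1,d2,...] -> [N,C,m] -> [N,m,C] -> [N*m,C], as in the diff comments
logit = logit.view(logit.size(0), logit.size(1), -1)
logit = logit.transpose(1, 2).contiguous()
logit = logit.view(-1, logit.size(-1))
target = target.view(-1, 1)                  # [N,d1,d2,...] -> [N*m,1]

print(logit.shape, target.shape)             # torch.Size([128, 5]) torch.Size([128, 1])

# each row of `logit` can now be indexed at its target class,
# which is exactly what `pt = logit.gather(1, target)` above relies on
pt = logit.gather(1, target).view(-1)
print(pt.shape)                              # torch.Size([128])
```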
