@@ -17,11 +17,11 @@ class BinaryFocalLoss(nn.Module):
         balance_index: (int) balance class index, should be specific when alpha is float
     """

-    def __init__(self, alpha=3, gamma=2, ignore_index=None, reduction='mean',**kwargs):
+    def __init__(self, alpha=3, gamma=2, ignore_index=None, reduction='mean', **kwargs):
         super(BinaryFocalLoss, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
-        self.smooth = 1e-6 # set '1e-4' when train with FP16
+        self.smooth = 1e-6  # set '1e-4' when train with FP16
         self.ignore_index = ignore_index
         self.reduction = reduction

@@ -55,17 +55,15 @@ def forward(self, output, target):
         neg_mask = neg_mask * valid_mask

         pos_weight = (pos_mask * torch.pow(1 - prob, self.gamma)).detach()
-        pos_loss = -torch.sum(pos_weight * torch.log(prob)) / (torch.sum(pos_weight) + 1e-4)
-
-
+        pos_loss = -pos_weight * torch.log(prob)  #/ (torch.sum(pos_weight) + 1e-4)
+
         neg_weight = (neg_mask * torch.pow(prob, self.gamma)).detach()
-        neg_loss = -self.alpha * torch.sum(neg_weight * F.logsigmoid(-output)) / (torch.sum(neg_weight) + 1e-4)
+        neg_loss = -self.alpha * neg_weight * F.logsigmoid(-output)  # / (torch.sum(neg_weight) + 1e-4)
         loss = pos_loss + neg_loss
-
+        loss = loss.mean()
         return loss


-
 class FocalLoss_Ori(nn.Module):
     """
     This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
@@ -79,7 +77,7 @@ class FocalLoss_Ori(nn.Module):
     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
     """

-    def __init__(self, num_class, alpha=[0.25,0.75], gamma=2, balance_index=-1, size_average=True):
+    def __init__(self, num_class, alpha=[0.25, 0.75], gamma=2, balance_index=-1, size_average=True):
         super(FocalLoss_Ori, self).__init__()
         self.num_class = num_class
         self.alpha = alpha
@@ -90,11 +88,11 @@ def __init__(self, num_class, alpha=[0.25,0.75], gamma=2, balance_index=-1, size
         if isinstance(self.alpha, (list, tuple)):
             assert len(self.alpha) == self.num_class
             self.alpha = torch.Tensor(list(self.alpha))
-        elif isinstance(self.alpha, (float,int)):
+        elif isinstance(self.alpha, (float, int)):
             assert 0 < self.alpha < 1.0, 'alpha should be in `(0,1)`)'
             assert balance_index > -1
             alpha = torch.ones((self.num_class))
-            alpha *= 1-self.alpha
+            alpha *= 1 - self.alpha
             alpha[balance_index] = self.alpha
             self.alpha = alpha
         elif isinstance(self.alpha, torch.Tensor):
@@ -107,9 +105,9 @@ def forward(self, logit, target):
         if logit.dim() > 2:
             # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
             logit = logit.view(logit.size(0), logit.size(1), -1)
-            logit = logit.transpose(1, 2).contiguous() # [N,C,d1*d2..] -> [N,d1*d2..,C]
-            logit = logit.view(-1, logit.size(-1)) # [N,d1*d2..,C]-> [N*d1*d2..,C]
-        target = target.view(-1, 1) # [N,d1,d2,...]->[N*d1*d2*...,1]
+            logit = logit.transpose(1, 2).contiguous()  # [N,C,d1*d2..] -> [N,d1*d2..,C]
+            logit = logit.view(-1, logit.size(-1))  # [N,d1*d2..,C]-> [N*d1*d2..,C]
+        target = target.view(-1, 1)  # [N,d1,d2,...]->[N*d1*d2*...,1]

         # -----------legacy way------------
         # idx = target.cpu().long()
@@ -120,19 +118,17 @@ def forward(self, logit, target):
         # pt = (one_hot_key * logit).sum(1) + epsilon

         # ----------memory saving way--------
-        pt = logit.gather(1, target).view(-1) + self.eps # avoid apply
+        pt = logit.gather(1, target).view(-1) + self.eps  # avoid apply
         logpt = pt.log()

         if self.alpha.device != logpt.device:
             alpha = self.alpha.to(logpt.device)
-            alpha_class = alpha.gather(0,target.view(-1))
-            logpt = alpha_class * logpt
+            alpha_class = alpha.gather(0, target.view(-1))
+            logpt = alpha_class * logpt
         loss = -1 * torch.pow(torch.sub(1.0, pt), self.gamma) * logpt

         if self.size_average:
             loss = loss.mean()
         else:
             loss = loss.sum()
         return loss
-
-
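For context, below is a minimal usage sketch of the two classes this commit touches. It is not part of the commit: the import path `focal_loss` is a placeholder for wherever this file lives, and it assumes, as the hunks above suggest, that `BinaryFocalLoss.forward(output, target)` takes raw logits plus a same-shaped binary target, and that a float `alpha` together with `balance_index` makes `FocalLoss_Ori` build a per-class weight vector.

import torch

from focal_loss import BinaryFocalLoss, FocalLoss_Ori  # placeholder import path

torch.manual_seed(0)

# BinaryFocalLoss: raw (pre-sigmoid) scores and a binary {0, 1} target of the same shape.
output = torch.randn(4, 1, 8, 8)
target = (torch.rand(4, 1, 8, 8) > 0.5).float()
bin_criterion = BinaryFocalLoss(alpha=3, gamma=2)
print(bin_criterion(output, target))  # scalar; after this commit the forward ends with loss.mean()

# FocalLoss_Ori with a float alpha: the constructor expands it into a per-class vector,
# e.g. num_class=3, alpha=0.25, balance_index=1 -> tensor([0.75, 0.25, 0.75]).
mc_criterion = FocalLoss_Ori(num_class=3, alpha=0.25, balance_index=1)
print(mc_criterion.alpha)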