@@ -18,51 +18,51 @@ class BinaryFocalLoss(nn.Module):
1818 balance_index: (int) balance class index, should be specific when alpha is float
1919 """
2020
21- def __init__ (self , alpha = [ 1.0 , 1.0 ], gamma = 2 , ignore_index = None , reduction = 'mean' ):
21+ def __init__ (self , alpha = 3 , gamma = 2 , ignore_index = None , reduction = 'mean' , ** kwargs ):
2222 super (BinaryFocalLoss , self ).__init__ ()
23- if alpha is None :
24- alpha = [0.25 , 0.75 ]
2523 self .alpha = alpha
2624 self .gamma = gamma
27- self .smooth = 1e-6
25+ self .smooth = 1e-6 # set '1e-4' when train with FP16
2826 self .ignore_index = ignore_index
2927 self .reduction = reduction
3028
3129 assert self .reduction in ['none' , 'mean' , 'sum' ]
3230
33- if self .alpha is None :
34- self .alpha = torch .ones (2 )
35- elif isinstance (self .alpha , (list , np .ndarray )):
36- self .alpha = np .asarray (self .alpha )
37- self .alpha = np .reshape (self .alpha , (2 ))
38- assert self .alpha .shape [0 ] == 2 , \
39- 'the `alpha` shape is not match the number of class'
40- elif isinstance (self .alpha , (float , int )):
41- self .alpha = np .asarray ([self .alpha , 1.0 - self .alpha ], dtype = np .float ).view (2 )
31+ # if self.alpha is None:
32+ # self.alpha = torch.ones(2)
33+ # elif isinstance(self.alpha, (list, np.ndarray)):
34+ # self.alpha = np.asarray(self.alpha)
35+ # self.alpha = np.reshape(self.alpha, (2))
36+ # assert self.alpha.shape[0] == 2, \
37+ # 'the `alpha` shape is not match the number of class'
38+ # elif isinstance(self.alpha, (float, int)):
39+ # self.alpha = np.asarray([self.alpha, 1.0 - self.alpha], dtype=np.float).view(2)
4240
43- else :
44- raise TypeError ('{} not supported' .format (type (self .alpha )))
41+ # else:
42+ # raise TypeError('{} not supported'.format(type(self.alpha)))
4543
4644 def forward (self , output , target ):
4745 prob = torch .sigmoid (output )
4846 prob = torch .clamp (prob , self .smooth , 1.0 - self .smooth )
4947
48+ valid_mask = None
49+ if self .ignore_index is not None :
50+ valid_mask = (target != self .ignore_index ).float ()
51+
5052 pos_mask = (target == 1 ).float ()
5153 neg_mask = (target == 0 ).float ()
54+ if valid_mask is not None :
55+ pos_mask = pos_mask * valid_mask
56+ neg_mask = neg_mask * valid_mask
57+
58+ pos_weight = (pos_mask * torch .pow (1 - prob , self .gamma )).detach ()
59+ pos_loss = - torch .sum (pos_weight * torch .log (prob )) / (torch .sum (pos_weight ) + 1e-4 )
60+
61+
62+ neg_weight = (neg_mask * torch .pow (prob , self .gamma )).detach ()
63+ neg_loss = - self .alpha * torch .sum (neg_weight * F .logsigmoid (- logit )) / (torch .sum (neg_weight ) + 1e-4 )
64+ loss = pos_loss + neg_loss
5265
53- pos_loss = - self .alpha [0 ] * torch .pow (torch .sub (1.0 , prob ), self .gamma ) * torch .log (prob ) * pos_mask
54- neg_loss = - self .alpha [1 ] * torch .pow (prob , self .gamma ) * \
55- torch .log (torch .sub (1.0 , prob )) * neg_mask
56-
57- neg_loss = neg_loss .sum ()
58- pos_loss = pos_loss .sum ()
59- num_pos = pos_mask .view (pos_mask .size (0 ), - 1 ).sum ()
60- num_neg = neg_mask .view (neg_mask .size (0 ), - 1 ).sum ()
61-
62- if num_pos == 0 :
63- loss = neg_loss
64- else :
65- loss = pos_loss / num_pos + neg_loss / num_neg
6666 return loss
6767
6868
0 commit comments