88
99class FocalBinaryTverskyLoss (Function ):
1010
11- def __init__ (ctx , alpha = 0.5 , beta = 0.5 , gamma = 1.0 , reduction = 'mean' ):
12- """
13-
14- :param alpha: controls the penalty for false positives.
15- :param beta: penalty for false negative.
16- :param gamma : focal coefficient range[1,3]
17- :param reduction: return mode
18- Notes:
19- alpha = beta = 0.5 -> dice coeff
20- alpha = beta = 1 -> tanimoto coeff
21- alpha + beta = 1 -> F beta coeff
22- add focal index -> loss=(1-T_index)**(1/gamma)
23-
24- """
25- ctx .alpha = alpha
26- ctx .beta = beta
27- ctx .epsilon = 1e-6
28- ctx .reduction = reduction
29- ctx .gamma = gamma
30- s = ctx .beta + ctx .alpha
31- if sum != 1 :
32- ctx .beta = ctx .beta / s
33- ctx .alpha = ctx .alpha / s
3411
35- # @staticmethod
12+ @staticmethod
3613 def forward (ctx , input , target ):
14+ _alpha = 0.5
15+ _beta = 0.5
16+ _gamma = 1.0
17+ _epsilon = 1e-6
18+ _reduction = 'mean'
19+
3720 batch_size = input .size (0 )
3821 _ , input_label = input .max (1 )
3922
@@ -49,13 +32,13 @@ def forward(ctx, input, target):
4932 ctx .P_NG = torch .sum (input_label * (1 - target_label ), 1 ) # FP
5033 ctx .NP_G = torch .sum ((1 - input_label ) * target_label , 1 ) # FN
5134
52- index = ctx .P_G / (ctx .P_G + ctx . alpha * ctx .P_NG + ctx . beta * ctx .NP_G + ctx . epsilon )
53- loss = torch .pow ((1 - index ), 1 / ctx . gamma )
35+ index = ctx .P_G / (ctx .P_G + _alpha * ctx .P_NG + _beta * ctx .NP_G + _epsilon )
36+ loss = torch .pow ((1 - index ), 1 / _gamma )
5437 # target_area = torch.sum(target_label, 1)
5538 # loss[target_area == 0] = 0
56- if ctx . reduction == 'none' :
39+ if _reduction == 'none' :
5740 loss = loss
58- elif ctx . reduction == 'sum' :
41+ elif _reduction == 'sum' :
5942 loss = torch .sum (loss )
6043 else :
6144 loss = torch .mean (loss )
@@ -72,24 +55,30 @@ def backward(ctx, grad_out):
7255 = 2*P_G
7356 (dT_loss/d_p0)=
7457 """
58+ _alpha = 0.5
59+ _beta = 0.5
60+ _gamma = 1.0
61+ _reduction = 'mean'
62+ _epsilon = 1e-6
63+
7564 inputs , target = ctx .saved_tensors
7665 inputs = inputs .float ()
7766 target = target .float ()
7867 batch_size = inputs .size (0 )
79- sum = ctx .P_G + ctx . alpha * ctx .P_NG + ctx . beta * ctx .NP_G + ctx . epsilon
68+ sum = ctx .P_G + _alpha * ctx .P_NG + _beta * ctx .NP_G + _epsilon
8069 P_G = ctx .P_G .view (batch_size , 1 , 1 , 1 , 1 )
8170 if inputs .dim () == 5 :
8271 sum = sum .view (batch_size , 1 , 1 , 1 , 1 )
8372 elif inputs .dim () == 4 :
8473 sum = sum .view (batch_size , 1 , 1 , 1 )
8574 P_G = ctx .P_G .view (batch_size , 1 , 1 , 1 )
86- sub = (ctx . alpha * (1 - target ) + target ) * P_G
75+ sub = (_alpha * (1 - target ) + target ) * P_G
8776
88- dL_dT = (1 / ctx . gamma ) * torch .pow ((P_G / sum ), (1 / ctx . gamma - 1 ))
77+ dL_dT = (1 / _gamma ) * torch .pow ((P_G / sum ), (1 / _gamma - 1 ))
8978 dT_dp0 = - 2 * (target / sum - sub / sum / sum )
9079 dL_dp0 = dL_dT * dT_dp0
9180
92- dT_dp1 = ctx . beta * (1 - target ) * P_G / sum / sum
81+ dT_dp1 = _beta * (1 - target ) * P_G / sum / sum
9382 dL_dp1 = dL_dT * dT_dp1
9483 grad_input = torch .cat ((dL_dp1 , dL_dp0 ), dim = 1 )
9584 # grad_input = torch.cat((grad_out.item() * dL_dp0, dL_dp0 * grad_out.item()), dim=1)
@@ -152,4 +141,4 @@ def forward(self, output, target, mask=None):
152141 loss = torch .sum (loss )
153142 else :
154143 loss = torch .mean (loss )
155- return loss
144+ return loss
0 commit comments