@@ -16,9 +16,9 @@ def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean'):
1616 :param gamma : focal coefficient range[1,3]
1717 :param reduction: return mode
1818 Notes:
19- alpha = beta = 0.5 = > dice coeff
20- alpha = beta = 1 = > tanimoto coeff
21- alpha + beta = 1 = > F beta coeff
19+ alpha = beta = 0.5 - > dice coeff
20+ alpha = beta = 1 - > tanimoto coeff
21+ alpha + beta = 1 - > F beta coeff
2222 add focal index -> loss=(1-T_index)**(1/gamma)
2323
2424 """
@@ -27,10 +27,10 @@ def __init__(ctx, alpha=0.5, beta=0.5, gamma=1.0, reduction='mean'):
2727 ctx .epsilon = 1e-6
2828 ctx .reduction = reduction
2929 ctx .gamma = gamma
30-        sum = ctx .beta + ctx .alpha
31-        if sum != 1 :
32-            ctx .beta = ctx .beta / sum
33-            ctx .alpha = ctx .alpha / sum
30+        s = ctx .beta + ctx .alpha
31+        if s != 1 :
32+            ctx .beta = ctx .beta / s
33+            ctx .alpha = ctx .alpha / s
3434
3535 # @staticmethod
3636 def forward (ctx , input , target ):
@@ -94,3 +94,60 @@ def backward(ctx, grad_out):
9494 grad_input = torch .cat ((dL_dp1 , dL_dp0 ), dim = 1 )
9595 # grad_input = torch.cat((grad_out.item() * dL_dp0, dL_dp0 * grad_out.item()), dim=1)
9696 return grad_input , None
97+
98+
class BinaryTverskyLossV2(nn.Module):
    """Tversky loss for binary segmentation (Salehi et al., 2017).

    The Tversky index generalizes the Dice coefficient by weighting false
    positives and false negatives separately:
        TI = TP / (TP + alpha * FP + beta * FN)
    and the loss is ``1 - TI`` per sample.
    """

    def __init__(self, alpha=0.3, beta=0.7, ignore_index=None, reduction='mean'):
        """Dice-style Tversky loss of a binary class.

        Args:
            alpha: controls the penalty for false positives.
            beta: controls the penalty for false negatives.
            ignore_index: target value that is ignored and does not
                contribute to the input gradient.
            reduction: reduction applied to the output:
                'none' | 'mean' | 'sum'.
        Shapes:
            output: tensor of shape [N, 1, (d,) h, w] WITHOUT sigmoid applied.
            target: tensor with the same shape as output.
        Returns:
            Loss tensor according to ``reduction``.
        Raises:
            ValueError: if ``reduction`` is not one of 'none' | 'mean' | 'sum'.
        """
        super(BinaryTverskyLossV2, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.ignore_index = ignore_index
        self.epsilon = 1e-6  # guards against division by zero for empty masks
        self.reduction = reduction
        # Normalize the weights so alpha + beta == 1, keeping the Tversky
        # index bounded in [0, 1].
        # BUG FIX: the original tested the *builtin* `sum` ("if sum != 1"),
        # which is always truthy; it must test the local total `s`.
        s = self.beta + self.alpha
        if s != 1:
            self.beta = self.beta / s
            self.alpha = self.alpha / s

    def forward(self, output, target):
        batch_size = output.size(0)

        if self.ignore_index is not None:
            # Zero out ignored positions in both tensors; they then contribute
            # nothing to TP/FP/FN and hence nothing to the gradient.
            valid_mask = (target != self.ignore_index).float()
            output = output.float().mul(valid_mask)  # can not use inplace for bp
            target = target.float().mul(valid_mask)

        # Logits -> probabilities, then flatten per sample.
        output = torch.sigmoid(output).view(batch_size, -1)
        target = target.view(batch_size, -1)

        P_G = torch.sum(output * target, 1)          # soft true positives
        P_NG = torch.sum(output * (1 - target), 1)   # soft false positives
        NP_G = torch.sum((1 - output) * target, 1)   # soft false negatives

        tversky_index = P_G / (P_G + self.alpha * P_NG + self.beta * NP_G + self.epsilon)

        loss = 1. - tversky_index
        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return torch.sum(loss)
        if self.reduction == 'mean':
            return torch.mean(loss)
        # BUG FIX: the docstring promises a raise on an unexpected reduction;
        # the original silently fell back to the mean.
        raise ValueError(
            "unexpected reduction {!r}; expected 'none', 'mean' or 'sum'".format(self.reduction))
0 commit comments