@@ -102,7 +102,7 @@ def __init__(self, alpha=0.3, beta=0.7, ignore_index=None, reduction='mean'):
102102 """Dice loss of binary class
103103 Args:
104104 alpha: controls the penalty for false positives.
105- beta: penalty for false negative.
105+ beta: penalty for false negative. Larger beta weigh recall higher
106106 ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
107107 reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
108108 Shapes:
@@ -117,29 +117,31 @@ def __init__(self, alpha=0.3, beta=0.7, ignore_index=None, reduction='mean'):
117117 self .alpha = alpha
118118 self .beta = beta
119119 self .ignore_index = ignore_index
120- self .epsilon = 1e-6
120+ self .smooth = 10
121121 self .reduction = reduction
122122 s = self .beta + self .alpha
123- if sum != 1 :
123+ if s != 1 :
124124 self .beta = self .beta / s
125125 self .alpha = self .alpha / s
126126
127- def forward (self , output , target ):
127+ def forward (self , output , target , mask = None ):
128128 batch_size = output .size (0 )
129-
129+ bg_target = 1 - target
130130 if self .ignore_index is not None :
131131 valid_mask = (target != self .ignore_index ).float ()
132132 output = output .float ().mul (valid_mask ) # can not use inplace for bp
133133 target = target .float ().mul (valid_mask )
134+ bg_target = bg_target .float ().mul (valid_mask )
134135
135136 output = torch .sigmoid (output ).view (batch_size , - 1 )
136137 target = target .view (batch_size , - 1 )
138+ bg_target = bg_target .view (batch_size , - 1 )
137139
138140 P_G = torch .sum (output * target , 1 ) # TP
139- P_NG = torch .sum (output * ( 1 - target ) , 1 ) # FP
141+ P_NG = torch .sum (output * bg_target , 1 ) # FP
140142 NP_G = torch .sum ((1 - output ) * target , 1 ) # FN
141143
142- tversky_index = P_G / (P_G + self .alpha * P_NG + self .beta * NP_G + self .epsilon )
144+ tversky_index = P_G / (P_G + self .alpha * P_NG + self .beta * NP_G + self .smooth )
143145
144146 loss = 1. - tversky_index
145147 # target_area = torch.sum(target_label, 1)
0 commit comments