@@ -112,6 +112,7 @@ def forward(self, output, target):
112112 loss = total_loss / (target .size (1 ) - len (self .ignore_index ))
113113 return loss
114114
115+
115116class WBCEWithLogitLoss (nn .Module ):
116117 def __init__ (self , weight = 1.0 , ignore_index = None , reduction = 'mean' ):
117118 """
@@ -178,11 +179,11 @@ def __init__(self, alpha=1.0, weight=1.0, ignore_index=None, reduction='mean'):
178179 """
179180 super (WBCE_DiceLoss , self ).__init__ ()
180181 assert reduction in ['none' , 'mean' , 'sum' ]
181- assert ( alpha >= 0 and alpha <= 1 ) , '`alpha` should in [0,1]'
182+ assert 0 <= alpha <= 1 , '`alpha` should in [0,1]'
182183 self .alpha = alpha
183184 self .ignore_index = ignore_index
184185 self .reduction = reduction
185- self .dice = BinaryDiceLoss (ignore_index = ignore_index , reduction = reduction )
186+ self .dice = BinaryDiceLoss (ignore_index = ignore_index , reduction = reduction , general = True )
186187 self .wbce = WBCEWithLogitLoss (weight = weight , ignore_index = ignore_index , reduction = reduction )
187188 self .dice_loss = None
188189 self .wbce_loss = None
@@ -191,22 +192,6 @@ def forward(self, output, target):
191192 self .dice_loss = self .dice (output , target )
192193 self .wbce_loss = self .wbce (output , target )
193194 loss = self .alpha * self .wbce_loss + self .dice_loss
194-
195- # if self.ignore_index is not None:
196- # mask = (target != self.ignore_index).float()
197- # output = output.mul(mask) # can not use inplace for bp
198- # target = target.float().mul(mask)
199- # bce_loss = self.bce(output, target)
200- # loss = self.alpha * bce_loss + (1.0 - self.alpha) * dice_loss
201- #
202- # if self.reduction == 'mean':
203- # return loss.mean()
204- # elif self.reduction == 'sum':
205- # return loss.sum()
206- # elif self.reduction == 'none':
207- # return loss
208- # else:
209- # raise Exception('Unexpected reduction {}'.format(self.reduction))
210195 return loss
211196
212197
0 commit comments