Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit da19001

Browse files
committed
update wbce_dice_loss
1 parent b0ecd6a commit da19001

File tree

1 file changed

+3
-18
lines changed

1 file changed

+3
-18
lines changed

DiceLoss/dice_loss.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def forward(self, output, target):
112112
loss = total_loss / (target.size(1) - len(self.ignore_index))
113113
return loss
114114

115+
115116
class WBCEWithLogitLoss(nn.Module):
116117
def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
117118
"""
@@ -178,11 +179,11 @@ def __init__(self, alpha=1.0, weight=1.0, ignore_index=None, reduction='mean'):
178179
"""
179180
super(WBCE_DiceLoss, self).__init__()
180181
assert reduction in ['none', 'mean', 'sum']
181-
assert (alpha >= 0 and alpha <= 1), '`alpha` should in [0,1]'
182+
assert 0 <= alpha <= 1, '`alpha` should in [0,1]'
182183
self.alpha = alpha
183184
self.ignore_index = ignore_index
184185
self.reduction = reduction
185-
self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction)
186+
self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction, general=True)
186187
self.wbce = WBCEWithLogitLoss(weight=weight, ignore_index=ignore_index, reduction=reduction)
187188
self.dice_loss = None
188189
self.wbce_loss = None
@@ -191,22 +192,6 @@ def forward(self, output, target):
191192
self.dice_loss = self.dice(output, target)
192193
self.wbce_loss = self.wbce(output, target)
193194
loss = self.alpha * self.wbce_loss + self.dice_loss
194-
195-
# if self.ignore_index is not None:
196-
# mask = (target != self.ignore_index).float()
197-
# output = output.mul(mask) # can not use inplace for bp
198-
# target = target.float().mul(mask)
199-
# bce_loss = self.bce(output, target)
200-
# loss = self.alpha * bce_loss + (1.0 - self.alpha) * dice_loss
201-
#
202-
# if self.reduction == 'mean':
203-
# return loss.mean()
204-
# elif self.reduction == 'sum':
205-
# return loss.sum()
206-
# elif self.reduction == 'none':
207-
# return loss
208-
# else:
209-
# raise Exception('Unexpected reduction {}'.format(self.reduction))
210195
return loss
211196

212197

0 commit comments

Comments
 (0)