Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit b0ecd6a

Browse files
committed
fix bug
1 parent eaf2c15 commit b0ecd6a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

DiceLoss/dice_loss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,11 @@ def forward(self, output, target):
148148
# neg_log = neg_log.mul(valid_mask)
149149

150150
# avoid `nan` loss
151-
output = torch.clamp(output, min=1e-8, max=1.0 - 1e-8)
152-
151+
eps=1e-6
152+
output = torch.clamp(output, min=eps, max=1.0 - eps)
153+
target = torch.clamp(target, min=eps, max=1.0 - eps)
153154
# loss = self.bce(output, target)
154-
loss = -self.weight * target.mul(output) - ((1.0 - target).mul(1.0 - output))
155+
loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output)))
155156
if self.reduction == 'mean':
156157
loss = torch.mean(loss)
157158
elif self.reduction == 'sum':

0 commit comments

Comments
 (0)