
Commit 3dfeb8c
update dice
1 parent dd27651

File tree

1 file changed: +11 −6 lines changed


DiceLoss/dice_loss.py

Lines changed: 11 additions & 6 deletions
@@ -113,21 +113,25 @@ def forward(self, output, target):
 
 
 class WBCEWithLogitLoss(nn.Module):
-    def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
-        """
-        Weight Binary Cross Entropy
-        Args:
+    """
+    Weighted Binary Cross Entropy.
+    `WBCE(p,t) = -β*t*log(p) - (1-t)*log(1-p)`
+    To decrease the number of false negatives, set β > 1.
+    To decrease the number of false positives, set β < 1.
+    Args:
         @param weight: positive sample weight
     Shapes:
         output: A tensor of shape [N, 1, (d,), h, w] without sigmoid activation function applied
         target: A tensor of the same shape as output
-        """
+    """
+    def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
         super(WBCEWithLogitLoss, self).__init__()
         assert reduction in ['none', 'mean', 'sum']
         self.ignore_index = ignore_index
         weight = float(weight)
         self.weight = weight
         self.reduction = reduction
+        self.smooth = 0.01
 
     def forward(self, output, target):
         assert output.shape[0] == target.shape[0], "output & target batch size don't match"
@@ -145,7 +149,8 @@ def forward(self, output, target):
         # avoid `nan` loss
         eps = 1e-6
         output = torch.clamp(output, min=eps, max=1.0 - eps)
-        target = torch.clamp(target, min=eps, max=1.0 - eps)
+        # soft label
+        target = torch.clamp(target, min=self.smooth, max=1.0 - self.smooth)
 
         # loss = self.bce(output, target)
         loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output)))
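
For context, a minimal usage sketch of how the updated loss might be exercised after this commit. This is not part of the commit: the import path DiceLoss.dice_loss, the tensor shapes, and the weight value are assumptions based on the file shown above.

# Usage sketch (assumed import path, shapes, and weight value).
import torch
from DiceLoss.dice_loss import WBCEWithLogitLoss

# weight (β) > 1 weights the positive class more heavily, reducing
# false negatives, as described in the new docstring.
criterion = WBCEWithLogitLoss(weight=2.0)

# Raw scores without sigmoid applied, shape [N, 1, h, w].
logits = torch.randn(4, 1, 64, 64, requires_grad=True)
target = torch.randint(0, 2, (4, 1, 64, 64)).float()

# Targets are clamped to [0.01, 0.99] internally via self.smooth = 0.01.
loss = criterion(logits, target)
loss.backward()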
