Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit 841a6df

Browse files
committed
update wbcelogloss
1 parent da19001 commit 841a6df

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

DiceLoss/dice_loss.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,25 +133,23 @@ def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
133133
# self.bce = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)
134134

135135
def forward(self, output, target):
136-
output = torch.sigmoid(output)
137-
138-
batch_size = output.size(0)
139-
output = output.view(batch_size, -1)
140-
target = target.view(batch_size, -1)
136+
assert output.shape[0] == target.shape[0], "output & target batch size don't match"
141137

142-
# pos_log = torch.log(output)
143-
# neg_log = torch.log(1.0 - output)
144138
if self.ignore_index is not None:
145139
valid_mask = (target != self.ignore_index).float()
146140
output = output.mul(valid_mask) # can not use inplace for bp
147141
target = target.float().mul(valid_mask)
148-
# pos_log = pos_log.mul(valid_mask)
149-
# neg_log = neg_log.mul(valid_mask)
150142

143+
batch_size = output.size(0)
144+
output = output.view(batch_size, -1)
145+
target = target.view(batch_size, -1)
146+
147+
output = torch.sigmoid(output)
151148
# avoid `nan` loss
152-
eps=1e-6
149+
eps = 1e-6
153150
output = torch.clamp(output, min=eps, max=1.0 - eps)
154151
target = torch.clamp(target, min=eps, max=1.0 - eps)
152+
155153
# loss = self.bce(output, target)
156154
loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output)))
157155
if self.reduction == 'mean':

0 commit comments

Comments
 (0)