This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit ef48f43: add batch loss
Parent: ce7ac2f


1 file changed: DiceLoss/dice_loss.py (11 additions, 4 deletions)
```diff
@@ -40,11 +40,14 @@ class BinaryDiceLoss(nn.Module):
         Exception if unexpected reduction
     """

-    def __init__(self, ignore_index=None, reduction='mean'):
+    def __init__(self, ignore_index=None, reduction='mean', **kwargs):
         super(BinaryDiceLoss, self).__init__()
-        self.smooth = 1
+        self.smooth = 1  # suggest setting a larger number when TP is large
         self.ignore_index = ignore_index
         self.reduction = reduction
+        self.batch_dice = False  # treat the whole batch as one large map when True
+        if 'batch_loss' in kwargs.keys():
+            self.batch_dice = kwargs['batch_loss']

     def forward(self, output, target, use_sigmoid=True):
         assert output.shape[0] == target.shape[0], "output & target batch size don't match"
@@ -56,8 +59,12 @@ def forward(self, output, target, use_sigmoid=True):
             output = output.mul(validmask)  # cannot use in-place ops; they break backprop
             target = target.float().mul(validmask)

-        output = output.contiguous().view(output.shape[0], -1)
-        target = target.contiguous().view(target.shape[0], -1).float()
+        dim0 = output.shape[0]
+        if self.batch_dice:
+            dim0 = 1
+
+        output = output.contiguous().view(dim0, -1)
+        target = target.contiguous().view(dim0, -1).float()

         num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
         den = torch.sum(output.pow(2) + target.pow(2), dim=1) + self.smooth
```
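To make the effect of the new flag concrete, here is a minimal sketch of the two reduction modes. The standalone `dice_loss` function below is illustrative, not the repository's API, and the final `1 - mean(dice)` step is an assumption, since the diff ends before the loss is returned:

```python
import torch

def dice_loss(output, target, smooth=1, batch_dice=False):
    # batch_dice=False: one dice score per sample (dim0 = batch size).
    # batch_dice=True: flatten the whole batch into a single map (dim0 = 1),
    # mirroring the dim0 logic added in this commit.
    dim0 = 1 if batch_dice else output.shape[0]
    output = output.contiguous().view(dim0, -1)
    target = target.contiguous().view(dim0, -1).float()
    num = 2 * torch.sum(output * target, dim=1) + smooth
    den = torch.sum(output.pow(2) + target.pow(2), dim=1) + smooth
    return 1 - (num / den).mean()  # assumed final reduction

pred = torch.rand(4, 1, 8, 8)                   # probabilities after sigmoid
mask = (torch.rand(4, 1, 8, 8) > 0.5).float()   # binary ground truth
print(dice_loss(pred, mask, batch_dice=False))  # mean of 4 per-sample scores
print(dice_loss(pred, mask, batch_dice=True))   # one score over the pooled batch
```

With batch dice, all pixels in the batch are pooled before the ratio is taken, so samples with little or no foreground no longer contribute near-degenerate per-sample scores, which is a common motivation for this variant. Through the class itself, the option is passed as a keyword argument: `BinaryDiceLoss(batch_loss=True)`.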
