@@ -40,11 +40,14 @@ class BinaryDiceLoss(nn.Module):
4040 Exception if unexpected reduction
4141 """
4242
43- def __init__ (self , ignore_index = None , reduction = 'mean' ):
43+ def __init__ (self , ignore_index = None , reduction = 'mean' , ** kwargs ):
4444 super (BinaryDiceLoss , self ).__init__ ()
45- self .smooth = 1
45+ self .smooth = 1 # suggest set a large number when TP is large
4646 self .ignore_index = ignore_index
4747 self .reduction = reduction
48+ self .batch_dice = False # treat a large map when True
49+ if 'batch_loss' in kwargs .keys ():
50+ self .batch_dice = kwargs ['batch_loss' ]
4851
4952 def forward (self , output , target , use_sigmoid = True ):
5053 assert output .shape [0 ] == target .shape [0 ], "output & target batch size don't match"
@@ -56,8 +59,12 @@ def forward(self, output, target, use_sigmoid=True):
5659 output = output .mul (validmask ) # can not use inplace for bp
5760 target = target .float ().mul (validmask )
5861
59- output = output .contiguous ().view (output .shape [0 ], - 1 )
60- target = target .contiguous ().view (target .shape [0 ], - 1 ).float ()
62+ dim0 = output .shape [0 ]
63+ if self .batch_dice :
64+ dim0 = 1
65+
66+ output = output .contiguous ().view (dim0 , - 1 )
67+ target = target .contiguous ().view (dim0 , - 1 ).float ()
6168
6269 num = 2 * torch .sum (torch .mul (output , target ), dim = 1 ) + self .smooth
6370 den = torch .sum (output .pow (2 ) + target .pow (2 ), dim = 1 ) + self .smooth
0 commit comments