@@ -113,21 +113,25 @@ def forward(self, output, target):
113113
114114
115115class WBCEWithLogitLoss (nn .Module ):
116- def __init__ (self , weight = 1.0 , ignore_index = None , reduction = 'mean' ):
117- """
118- Weight Binary Cross Entropy
119- Args:
116+ """
117+ Weighted Binary Cross Entropy.
118+ `WBCE(p,t)=-β*t*log(p)-(1-t)*log(1-p)`
119+ To decrease the number of false negatives, set β>1.
120+ To decrease the number of false positives, set β<1.
121+ Args:
120122 @param weight: positive sample weight
121123 Shapes:
122124 output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied
123125 target: A tensor of shape same with output
124- """
126+ """
127+ def __init__ (self , weight = 1.0 , ignore_index = None , reduction = 'mean' ):
125128 super (WBCEWithLogitLoss , self ).__init__ ()
126129 assert reduction in ['none' , 'mean' , 'sum' ]
127130 self .ignore_index = ignore_index
128131 weight = float (weight )
129132 self .weight = weight
130133 self .reduction = reduction
134+ self .smooth = 0.01
131135
132136 def forward (self , output , target ):
133137 assert output .shape [0 ] == target .shape [0 ], "output & target batch size don't match"
@@ -145,7 +149,8 @@ def forward(self, output, target):
145149 # avoid `nan` loss
146150 eps = 1e-6
147151 output = torch .clamp (output , min = eps , max = 1.0 - eps )
148- target = torch .clamp (target , min = eps , max = 1.0 - eps )
152+ # soft label
153+ target = torch .clamp (target , min = self .smooth , max = 1.0 - self .smooth )
149154
150155 # loss = self.bce(output, target)
151156 loss = - self .weight * target .mul (torch .log (output )) - ((1.0 - target ).mul (torch .log (1.0 - output )))
0 commit comments