@@ -133,25 +133,23 @@ def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
133133 # self.bce = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)
134134
135135 def forward (self , output , target ):
136- output = torch .sigmoid (output )
137-
138- batch_size = output .size (0 )
139- output = output .view (batch_size , - 1 )
140- target = target .view (batch_size , - 1 )
136+ assert output .shape [0 ] == target .shape [0 ], "output & target batch size don't match"
141137
142- # pos_log = torch.log(output)
143- # neg_log = torch.log(1.0 - output)
144138 if self .ignore_index is not None :
145139 valid_mask = (target != self .ignore_index ).float ()
146140 output = output .mul (valid_mask ) # can not use inplace for bp
147141 target = target .float ().mul (valid_mask )
148- # pos_log = pos_log.mul(valid_mask)
149- # neg_log = neg_log.mul(valid_mask)
150142
143+ batch_size = output .size (0 )
144+ output = output .view (batch_size , - 1 )
145+ target = target .view (batch_size , - 1 )
146+
147+ output = torch .sigmoid (output )
151148 # avoid `nan` loss
152- eps = 1e-6
149+ eps = 1e-6
153150 output = torch .clamp (output , min = eps , max = 1.0 - eps )
154151 target = torch .clamp (target , min = eps , max = 1.0 - eps )
152+
155153 # loss = self.bce(output, target)
156154 loss = - self .weight * target .mul (torch .log (output )) - ((1.0 - target ).mul (torch .log (1.0 - output )))
157155 if self .reduction == 'mean' :
0 commit comments