@@ -165,13 +165,13 @@ def threshold_double_backwards(ctx, ggI):
165165 return gI , ggO , None , None , None
166166
167167
168- def mseloss_double_backwards (ctx , ggI ):
168+ def klddivloss_double_backwards (ctx , ggI ):
169169 size_average = ctx .additional_args [0 ]
170170 input , target , gO = ctx .saved_variables
171171 div_factor = input .nelement () if size_average else 1
172172
173- gI = ggI * ( gO * 2. / div_factor ). expand_as ( input )
174- ggO = (ggI * ( input - target )) .sum () * ( 2. / div_factor )
173+ gI = None
174+ ggO = (ggI * target ).sum () / - div_factor
175175
176176 return gI , None , ggO , None , None
177177
@@ -189,6 +189,17 @@ def l1loss_double_backwards(ctx, ggI):
189189 return gI , None , ggO , None , None
190190
191191
192+ def mseloss_double_backwards (ctx , ggI ):
193+ size_average = ctx .additional_args [0 ]
194+ input , target , gO = ctx .saved_variables
195+ div_factor = input .nelement () if size_average else 1
196+
197+ gI = ggI * (gO * 2. / div_factor ).expand_as (input )
198+ ggO = (ggI * (input - target )).sum () * (2. / div_factor )
199+
200+ return gI , None , ggO , None , None
201+
202+
192203def nllloss_double_backwards (ctx , ggI ):
193204 t = ctx .saved_variables
194205 target = t [1 ]
@@ -237,7 +248,7 @@ def smoothl1loss_double_backwards(ctx, ggI):
237248 large_error_neg_mask = (((input_sub_target <= 0 ) + large_error_mask ) == 2 ).type_as (ggI )
238249 small_error_mask = small_error_mask .type_as (ggI )
239250
240- gI = 1. / div_factor * small_error_mask * ggI * gO
251+ gI = small_error_mask * ggI * gO / div_factor
241252 ggO = (ggI * (input_sub_target * small_error_mask + large_error_pos_mask - large_error_neg_mask )).sum () / div_factor
242253
243254 return gI , None , ggO , None , None , None
@@ -254,6 +265,7 @@ def smoothl1loss_double_backwards(ctx, ggI):
254265 'Softplus' : softplus_double_backwards ,
255266 'Softshrink' : softshrink_double_backwards ,
256267 'Threshold' : threshold_double_backwards ,
268+ 'KLDivLoss' : klddivloss_double_backwards ,
257269 'L1Loss' : l1loss_double_backwards ,
258270 'MSELoss' : mseloss_double_backwards ,
259271 'NLLLoss' : nllloss_double_backwards ,
0 commit comments