Skip to content

Commit 1efe387

Browse files
committed
Implement KLDivLoss double backwards.
1 parent 5106ce6 commit 1efe387

File tree

2 files changed

+23
-5
lines changed

test/common_nn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,13 @@
280280
module_name='KLDivLoss',
281281
input=torch.rand(10, 10).log(),
282282
target=torch.rand(10, 10),
283-
check_gradgrad=False,
283+
),
284+
dict(
285+
module_name='KLDivLoss',
286+
constructor_args=(False,),
287+
input=torch.rand(10, 10).log(),
288+
target=torch.rand(10, 10),
289+
desc='no_size_average',
284290
),
285291
dict(
286292
module_name='MSELoss',

torch/nn/_functions/thnn/auto_double_backwards.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,13 @@ def threshold_double_backwards(ctx, ggI):
165165
return gI, ggO, None, None, None
166166

167167

168-
def mseloss_double_backwards(ctx, ggI):
168+
def klddivloss_double_backwards(ctx, ggI):
169169
size_average = ctx.additional_args[0]
170170
input, target, gO = ctx.saved_variables
171171
div_factor = input.nelement() if size_average else 1
172172

173-
gI = ggI * (gO * 2. / div_factor).expand_as(input)
174-
ggO = (ggI * (input - target)).sum() * (2. / div_factor)
173+
gI = None
174+
ggO = (ggI * target).sum() / -div_factor
175175

176176
return gI, None, ggO, None, None
177177

@@ -189,6 +189,17 @@ def l1loss_double_backwards(ctx, ggI):
189189
return gI, None, ggO, None, None
190190

191191

192+
def mseloss_double_backwards(ctx, ggI):
193+
size_average = ctx.additional_args[0]
194+
input, target, gO = ctx.saved_variables
195+
div_factor = input.nelement() if size_average else 1
196+
197+
gI = ggI * (gO * 2. / div_factor).expand_as(input)
198+
ggO = (ggI * (input - target)).sum() * (2. / div_factor)
199+
200+
return gI, None, ggO, None, None
201+
202+
192203
def nllloss_double_backwards(ctx, ggI):
193204
t = ctx.saved_variables
194205
target = t[1]
@@ -237,7 +248,7 @@ def smoothl1loss_double_backwards(ctx, ggI):
237248
large_error_neg_mask = (((input_sub_target <= 0) + large_error_mask) == 2).type_as(ggI)
238249
small_error_mask = small_error_mask.type_as(ggI)
239250

240-
gI = 1. / div_factor * small_error_mask * ggI * gO
251+
gI = small_error_mask * ggI * gO / div_factor
241252
ggO = (ggI * (input_sub_target * small_error_mask + large_error_pos_mask - large_error_neg_mask)).sum() / div_factor
242253

243254
return gI, None, ggO, None, None, None
@@ -254,6 +265,7 @@ def smoothl1loss_double_backwards(ctx, ggI):
254265
'Softplus': softplus_double_backwards,
255266
'Softshrink': softshrink_double_backwards,
256267
'Threshold': threshold_double_backwards,
268+
'KLDivLoss': klddivloss_double_backwards,
257269
'L1Loss': l1loss_double_backwards,
258270
'MSELoss': mseloss_double_backwards,
259271
'NLLLoss': nllloss_double_backwards,

0 commit comments

Comments (0)