Skip to content

Commit 19d4c37

Browse files
committed
Implement MSELoss double backward.
1 parent 7875c02 commit 19d4c37

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

test/common_nn.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,14 @@
287287
input=torch.randn(2, 3, 4, 5),
288288
target=torch.randn(2, 3, 4, 5),
289289
reference_fn=lambda i, t, _: (i - t).abs().pow(2).sum() / i.numel(),
290-
check_gradgrad=False,
290+
),
291+
dict(
292+
module_name='MSELoss',
293+
constructor_args=(False,),
294+
input=torch.randn(2, 3, 4, 5),
295+
target=torch.randn(2, 3, 4, 5),
296+
reference_fn=lambda i, t, _: (i - t).abs().pow(2).sum(),
297+
desc='no_size_average',
291298
),
292299
dict(
293300
module_name='BCELoss',

torch/nn/_functions/thnn/auto_double_backwards.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,17 @@ def threshold_double_backwards(ctx, ggI):
165165
return gI, ggO, None, None, None
166166

167167

168+
def mseloss_double_backwards(ctx, ggI):
169+
size_average = ctx.additional_args[0]
170+
input, target, gO = ctx.saved_variables
171+
div_factor = input.nelement() if size_average else 1
172+
173+
gI = ggI * (gO * 2. / div_factor).expand_as(input)
174+
ggO = (ggI * (input - target)).sum() * (2. / div_factor)
175+
176+
return gI, None, ggO, None, None
177+
178+
168179
def l1loss_double_backwards(ctx, ggI):
169180
size_average = ctx.additional_args[0]
170181
input, target, grad_output = ctx.saved_variables
@@ -227,6 +238,7 @@ def nllloss_double_backwards(ctx, ggI):
227238
'Softshrink': softshrink_double_backwards,
228239
'Threshold': threshold_double_backwards,
229240
'L1Loss': l1loss_double_backwards,
241+
'MSELoss': mseloss_double_backwards,
230242
'NLLLoss': nllloss_double_backwards,
231243
'NLLLoss2d': nllloss_double_backwards,
232244
}

0 commit comments

Comments
 (0)