Skip to content

Commit f484a5f

Browse files
gchanansoumith
authored andcommitted
Implement LogSoftmax double backwards (pytorch#2270)
1 parent aebec91 commit f484a5f

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

test/common_nn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,12 @@
101101
module_name='LogSoftmax',
102102
input_size=(10, 20),
103103
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
104-
check_gradgrad=False,
105104
),
106105
dict(
107106
module_name='LogSoftmax',
108107
input_size=(1, 3, 10, 20),
109108
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
110109
desc='multiparam',
111-
check_gradgrad=False,
112110
),
113111
dict(
114112
module_name='ELU',

torch/nn/_functions/thnn/auto_double_backwards.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ def leakyrelu_double_backwards(ctx, ggI):
3939
return gI, ggO, None, None, None
4040

4141

42+
def logsoftmax_double_backwards(ctx, ggI):
43+
t = ctx.saved_variables
44+
gO, output = t[1], t[2]
45+
46+
output_exp = output.exp()
47+
gO_sum = gO.sum(dim=1, keepdim=True)
48+
ggI_output_exp = ggI * output_exp
49+
ggI_output_exp_sum = ggI_output_exp.sum(dim=1, keepdim=True)
50+
51+
gI = output_exp * gO_sum * ggI_output_exp_sum - ggI_output_exp * gO_sum
52+
ggO = ggI - ggI_output_exp_sum
53+
54+
return gI, ggO, None, None, None, None
55+
56+
4257
def softmax_double_backwards(ctx, ggI):
4358
t = ctx.saved_variables
4459
gO, output = t[1], t[2]
@@ -126,6 +141,7 @@ def nllloss_double_backwards(ctx, ggI):
126141
'ELU': elu_double_backwards,
127142
'Hardtanh': hardtanh_double_backwards,
128143
'LeakyReLU': leakyrelu_double_backwards,
144+
'LogSoftmax': logsoftmax_double_backwards,
129145
'Softmax': softmax_double_backwards,
130146
'Threshold': threshold_double_backwards,
131147
'L1Loss': l1loss_double_backwards,

0 commit comments

Comments
 (0)