Skip to content

Commit 8d38c0e

Browse files
committed
Implement Softplus double backwards.
1 parent ea9a782 commit 8d38c0e

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

test/common_nn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,21 @@
139139
module_name='Softplus',
140140
input_size=(10, 20),
141141
reference_fn=lambda i, _: torch.log(1 + torch.exp(i)),
142-
check_gradgrad=False,
143142
),
144143
dict(
145144
module_name='Softplus',
146145
constructor_args=(2,),
147146
input_size=(10, 20),
148147
reference_fn=lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
149148
desc='beta',
150-
check_gradgrad=False,
149+
),
150+
dict(
151+
module_name='Softplus',
152+
constructor_args=(2, -100),
153+
input_size=(10, 20),
154+
reference_fn=(lambda i, _: ((i * 2) > -100).type_as(i) * i +
155+
((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
156+
desc='beta_threshold',
151157
),
152158
dict(
153159
module_name='Softshrink',

torch/nn/_functions/thnn/auto_double_backwards.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,25 @@ def softmax_double_backwards(ctx, ggI):
9191
return gI, ggO, None, None, None, None
9292

9393

94+
def softplus_double_backwards(ctx, ggI):
95+
t = ctx.saved_variables
96+
input, gO, output = t[0], t[1], t[2]
97+
beta, threshold = ctx.additional_args[0], ctx.additional_args[1]
98+
99+
input_beta = input * beta
100+
above_threshold = ((input_beta) > threshold).type_as(ggI)
101+
below_threshold = ((input_beta) <= threshold).type_as(ggI)
102+
103+
exp_output_beta = (output * beta).exp()
104+
first_deriv = (exp_output_beta - 1) / exp_output_beta
105+
first_deriv_below_threshold = first_deriv * below_threshold
106+
107+
gI = ggI * gO * first_deriv_below_threshold * beta / exp_output_beta
108+
ggO = ggI * (above_threshold + first_deriv_below_threshold)
109+
110+
return gI, ggO, None, None, None, None
111+
112+
94113
def threshold_double_backwards(ctx, ggI):
95114
t = ctx.saved_variables
96115
input = t[0]
@@ -158,6 +177,7 @@ def nllloss_double_backwards(ctx, ggI):
158177
'LeakyReLU': leakyrelu_double_backwards,
159178
'LogSoftmax': logsoftmax_double_backwards,
160179
'Softmax': softmax_double_backwards,
180+
'Softplus': softplus_double_backwards,
161181
'Threshold': threshold_double_backwards,
162182
'L1Loss': l1loss_double_backwards,
163183
'NLLLoss': nllloss_double_backwards,

0 commit comments

Comments
 (0)