@@ -91,6 +91,25 @@ def softmax_double_backwards(ctx, ggI):
9191 return gI , ggO , None , None , None , None
9292
9393
def softplus_double_backwards(ctx, ggI):
    """Compute the double-backwards pass for Softplus.

    Args:
        ctx: autograd context carrying ``saved_variables`` (input, grad_output,
            forward output) and ``additional_args`` (beta, threshold).
        ggI: gradient w.r.t. the first-backwards input gradient.

    Returns:
        A 6-tuple ``(gI, ggO, None, None, None, None)`` matching the arity of
        the backward's inputs; only the first two entries carry gradients.
    """
    input, gO, output = ctx.saved_variables[:3]
    beta, threshold = ctx.additional_args[:2]

    # Masks for the two regimes of the forward pass: above `threshold`
    # softplus is the identity, below it the smooth log1p(exp(.)) branch.
    scaled_input = input * beta
    hi_mask = (scaled_input > threshold).type_as(ggI)
    lo_mask = (scaled_input <= threshold).type_as(ggI)

    # Below threshold, exp(beta*output) == 1 + exp(beta*input), so the first
    # derivative sigmoid(beta*input) can be recovered from the saved output.
    exp_out_beta = (output * beta).exp()
    sig = (exp_out_beta - 1) / exp_out_beta
    sig_lo = sig * lo_mask

    # Second derivative: beta * sig * (1 - sig), where (1 - sig) == 1/exp_out_beta.
    gI = ggI * gO * sig_lo * beta / exp_out_beta
    # First derivative: 1 above threshold, sigmoid(beta*input) below.
    ggO = ggI * (hi_mask + sig_lo)

    return gI, ggO, None, None, None, None
111+
112+
94113def threshold_double_backwards (ctx , ggI ):
95114 t = ctx .saved_variables
96115 input = t [0 ]
@@ -158,6 +177,7 @@ def nllloss_double_backwards(ctx, ggI):
158177 'LeakyReLU' : leakyrelu_double_backwards ,
159178 'LogSoftmax' : logsoftmax_double_backwards ,
160179 'Softmax' : softmax_double_backwards ,
180+ 'Softplus' : softplus_double_backwards ,
161181 'Threshold' : threshold_double_backwards ,
162182 'L1Loss' : l1loss_double_backwards ,
163183 'NLLLoss' : nllloss_double_backwards ,
0 commit comments