11from torch .autograd import Variable
2+ import torch
23
34
45def elu_double_backwards (ctx , ggI ):
@@ -15,6 +16,32 @@ def elu_double_backwards(ctx, ggI):
1516 return gI , ggO , None , None , None , None
1617
1718
19+ def gatedlinear_double_backwards (ctx , ggI ):
20+ input , gO = ctx .saved_variables
21+ dim = ctx .additional_args [0 ]
22+
23+ input_size = input .size (dim ) // 2
24+
25+ first_half = input .narrow (dim , 0 , input_size )
26+ second_half = input .narrow (dim , input_size , input_size )
27+ sig_second_half = second_half .sigmoid ()
28+ one_sub_sig_second_half = 1 - sig_second_half
29+ sig_one_sub_sig = sig_second_half * one_sub_sig_second_half
30+
31+ ggI_first_half = ggI .narrow (dim , 0 , input_size )
32+ ggI_second_half = ggI .narrow (dim , input_size , input_size )
33+ ggI_second_half_times_first_half = ggI_second_half * first_half
34+
35+ gI_first_half = ggI_second_half * gO * sig_one_sub_sig
36+ second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig
37+ gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig
38+ gI = torch .cat ((gI_first_half , gI_second_half ), dim )
39+
40+ ggO = ggI_first_half * sig_second_half + ggI_second_half_times_first_half * sig_one_sub_sig
41+
42+ return gI , ggO , None , None , None
43+
44+
1845def hardshrink_double_backwards (ctx , ggI ):
1946 t = ctx .saved_variables
2047 input = t [0 ]
@@ -189,6 +216,7 @@ def nllloss_double_backwards(ctx, ggI):
189216
190217double_backwards_fns = {
191218 'ELU' : elu_double_backwards ,
219+ 'GatedLinear' : gatedlinear_double_backwards ,
192220 'Hardshrink' : hardshrink_double_backwards ,
193221 'Hardtanh' : hardtanh_double_backwards ,
194222 'LeakyReLU' : leakyrelu_double_backwards ,
0 commit comments