Skip to content

Commit 7875c02

Browse files
committed
Implement GLU double backwards.
1 parent 9a243ab commit 7875c02

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

test/test_nn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3864,7 +3864,12 @@ def add_test(test):
38643864
dict(
38653865
module_name='GLU',
38663866
input_size=(5, 6),
3867-
check_gradgrad=False,
3867+
),
3868+
dict(
3869+
module_name='GLU',
3870+
constructor_args=(1,),
3871+
input_size=(5, 6, 7),
3872+
desc='dim'
38683873
),
38693874
]
38703875

torch/nn/_functions/thnn/auto_double_backwards.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from torch.autograd import Variable
2+
import torch
23

34

45
def elu_double_backwards(ctx, ggI):
@@ -15,6 +16,32 @@ def elu_double_backwards(ctx, ggI):
1516
return gI, ggO, None, None, None, None
1617

1718

19+
def gatedlinear_double_backwards(ctx, ggI):
20+
input, gO = ctx.saved_variables
21+
dim = ctx.additional_args[0]
22+
23+
input_size = input.size(dim) // 2
24+
25+
first_half = input.narrow(dim, 0, input_size)
26+
second_half = input.narrow(dim, input_size, input_size)
27+
sig_second_half = second_half.sigmoid()
28+
one_sub_sig_second_half = 1 - sig_second_half
29+
sig_one_sub_sig = sig_second_half * one_sub_sig_second_half
30+
31+
ggI_first_half = ggI.narrow(dim, 0, input_size)
32+
ggI_second_half = ggI.narrow(dim, input_size, input_size)
33+
ggI_second_half_times_first_half = ggI_second_half * first_half
34+
35+
gI_first_half = ggI_second_half * gO * sig_one_sub_sig
36+
second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig
37+
gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig
38+
gI = torch.cat((gI_first_half, gI_second_half), dim)
39+
40+
ggO = ggI_first_half * sig_second_half + ggI_second_half_times_first_half * sig_one_sub_sig
41+
42+
return gI, ggO, None, None, None
43+
44+
1845
def hardshrink_double_backwards(ctx, ggI):
1946
t = ctx.saved_variables
2047
input = t[0]
@@ -189,6 +216,7 @@ def nllloss_double_backwards(ctx, ggI):
189216

190217
double_backwards_fns = {
191218
'ELU': elu_double_backwards,
219+
'GatedLinear': gatedlinear_double_backwards,
192220
'Hardshrink': hardshrink_double_backwards,
193221
'Hardtanh': hardtanh_double_backwards,
194222
'LeakyReLU': leakyrelu_double_backwards,

0 commit comments

Comments
 (0)