Skip to content

Commit 511cb20

Browse files
fmassasoumith
authored andcommitted
Add Gesv to autograd (pytorch#1733)
* Add Gesv to autograd * Add TODO for backprop through LU
1 parent e3305eb commit 511cb20

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

test/test_autograd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,6 +1419,7 @@ class dont_convert(tuple):
14191419
(Cross, (), ((S, 3), (S, 3))),
14201420
(Cross, (), ((S, 3, S), (S, 3, S), 1), 'dim'),
14211421
(Inverse, (), ((S, S),), '', (), [skipIfNoLapack]),
1422+
(Gesv, (), ((S, S), (S, S)), '', (), [skipIfNoLapack]),
14221423
(Clone, (), ((S, M, S),)),
14231424
(Squeeze, (), ((S, 1, M, 1),)),
14241425
# TODO: enable neg dim checks
@@ -1552,6 +1553,7 @@ class dont_convert(tuple):
15521553
('cross', (S, 3), ((S, 3),)),
15531554
('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
15541555
('inverse', (S, S), (), '', (), [skipIfNoLapack]),
1556+
('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
15551557
('clone', (S, M, S), ()),
15561558
('eq', (S, S, S), ((S, S, S),)),
15571559
('ne', (S, S, S), ((S, S, S),)),

torch/autograd/_functions/linalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,21 @@ def forward(ctx, input):
8484
def backward(ctx, grad_output):
8585
inverse, = ctx.saved_variables
8686
return -torch.mm(inverse.t(), torch.mm(grad_output, inverse.t()))
87+
88+
89+
class Gesv(Function):
90+
91+
@staticmethod
92+
def forward(ctx, b, a):
93+
# TODO see if one can backprop through LU
94+
X, LU = torch.gesv(b, a)
95+
ctx.save_for_backward(X, a)
96+
ctx.mark_non_differentiable(LU)
97+
return X, LU
98+
99+
@staticmethod
100+
def backward(ctx, grad_output, grad_LU=None):
101+
X, a = ctx.saved_variables
102+
grad_b, _ = torch.gesv(grad_output, a.t())
103+
grad_a = -torch.mm(grad_b, X.t())
104+
return grad_b, grad_a

torch/autograd/variable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,9 @@ def cross(self, other, dim=-1):
719719
def inverse(self):
720720
return Inverse.apply(self)
721721

722+
def gesv(self, a):
723+
return Gesv.apply(self, a)
724+
722725
def multinomial(self, num_samples=1, with_replacement=False):
723726
return Multinomial(num_samples, with_replacement)(self)
724727

0 commit comments

Comments
 (0)