Commit 4853cc0

chenyuntc authored and soumith committed

convert linalg.py to new-style functions (pytorch#1638)

1 parent ac1c674
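
Every file in this commit follows the same pattern: PyTorch's old-style autograd functions carried extra arguments as instance state set in __init__ and were instantiated before being called, while new-style functions are stateless, take a ctx object in @staticmethod forward/backward, and are invoked through .apply. Extra non-tensor arguments move into forward's signature, and backward returns a matching None for each of them. Below is a minimal sketch of the conversion, assuming the 0.2-era API; Scale is a hypothetical example, not part of this commit.

# Hypothetical `Scale` function illustrating the old-style/new-style
# conversion performed in this commit (0.2-era autograd API assumed).
from torch.autograd import Function


class ScaleOld(Function):
    # Old style: the extra argument is instance state set in __init__;
    # the function object is instantiated, then called on its inputs.

    def __init__(self, factor=1.0):
        super(ScaleOld, self).__init__()
        self.factor = factor

    def forward(self, input):
        return input * self.factor

    def backward(self, grad_output):
        return grad_output * self.factor


class ScaleNew(Function):
    # New style: stateless; the extra argument joins forward's signature,
    # per-call state lives on `ctx`, and backward returns one value per
    # forward argument (None for the non-differentiable `factor`).

    @staticmethod
    def forward(ctx, input, factor=1.0):
        ctx.factor = factor
        return input * ctx.factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None


# Call sites change from instantiate-then-call to .apply:
#   ScaleOld(2.0)(x)  ->  ScaleNew.apply(x, 2.0)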

File tree

3 files changed: +54, −53 lines

test/test_autograd.py
torch/autograd/_functions/linalg.py
torch/autograd/variable.py

test/test_autograd.py

Lines changed: 5 additions & 5 deletions
@@ -1409,15 +1409,15 @@ class dont_convert(tuple):
     (Resize, (), ((S, S, S), torch.Size([S * S, S]))),
     (Diag, (), ((S, S),), '2d'),
     (Diag, (), ((S,),), '1d'),
-    (Diag, (1,), ((S, S),), '2d_1'),
-    (Diag, (2,), ((S, S),), '2d_2'),
+    (Diag, (), ((S, S), 1), '2d_1'),
+    (Diag, (), ((S, S), 2), '2d_2'),
     (Tril, (), ((S, S),)),
-    (Tril, (2,), ((S, S),), 'idx'),
+    (Tril, (), ((S, S), 2), 'idx'),
     (Triu, (), ((S, S),)),
-    (Triu, (2,), ((S, S),), 'idx'),
+    (Triu, (), ((S, S), 2), 'idx'),
     (Trace, (), ((S, S),)),
     (Cross, (), ((S, 3), (S, 3))),
-    (Cross, (1,), ((S, 3, S), (S, 3, S)), 'dim'),
+    (Cross, (), ((S, 3, S), (S, 3, S), 1), 'dim'),
     (Inverse, (), ((S, S),), '', (), [skipIfNoLapack]),
     (Clone, (), ((S, M, S),)),
     (Squeeze, (), ((S, 1, M, 1),)),

torch/autograd/_functions/linalg.py

Lines changed: 44 additions & 43 deletions
@@ -1,74 +1,75 @@
 import torch

 from ..function import Function
+from ..variable import Variable


 class Diag(Function):

-    def __init__(self, diagonal_idx=0):
-        super(Diag, self).__init__()
-        self.diagonal_idx = diagonal_idx
-
-    def forward(self, input):
-        return input.diag(self.diagonal_idx)
+    @staticmethod
+    def forward(ctx, input, diagonal_idx=0):
+        ctx.diagonal_idx = diagonal_idx
+        return input.diag(ctx.diagonal_idx)

-    def backward(self, grad_output):
-        return grad_output.diag(self.diagonal_idx)
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.diag(ctx.diagonal_idx), None


 class Tril(Function):

-    def __init__(self, diagonal_idx=0):
-        super(Tril, self).__init__()
-        self.diagonal_idx = diagonal_idx
-
-    def forward(self, input):
-        return input.tril(self.diagonal_idx)
+    @staticmethod
+    def forward(ctx, input, diagonal_idx=0):
+        ctx.diagonal_idx = diagonal_idx
+        return input.tril(ctx.diagonal_idx)

-    def backward(self, grad_output):
-        return grad_output.tril(self.diagonal_idx)
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.tril(ctx.diagonal_idx), None


 class Triu(Function):

-    def __init__(self, diagonal_idx=0):
-        super(Triu, self).__init__()
-        self.diagonal_idx = diagonal_idx
-
-    def forward(self, input):
-        return input.triu(self.diagonal_idx)
+    @staticmethod
+    def forward(ctx, input, diagonal_idx=0):
+        ctx.diagonal_idx = diagonal_idx
+        return input.triu(ctx.diagonal_idx)

-    def backward(self, grad_output):
-        return grad_output.triu(self.diagonal_idx)
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.triu(ctx.diagonal_idx), None


 class Trace(Function):

-    def forward(self, input):
-        self.isize = input.size()
-        return input.new((input.trace(),))
+    @staticmethod
+    def forward(ctx, input):
+        ctx.isize = input.size()
+        return input.new((input.trace(),))

-    def backward(self, grad_output):
-        isize = self.isize
-        grad_input = grad_output.new(isize).zero_()
-        grad_input.view(-1)[::(isize[1] + 1)] = grad_output[0]
-        return grad_input
+    @staticmethod
+    def backward(ctx, grad_output):
+        isize = ctx.isize
+        min_size = min(isize)
+        grad_input = Variable(grad_output.data.new(isize).zero_()).view(-1)
+        grad_input[::(isize[1] + 1)] = grad_output.expand(min_size)
+        return grad_input.view(isize)


 class Cross(Function):

-    def __init__(self, dim=-1):
-        self.dim = dim
-
-    def forward(self, input, other):
-        self.save_for_backward(input, other)
-        return torch.cross(input, other, self.dim)
+    @staticmethod
+    def forward(ctx, input, other, dim=-1):
+        ctx.dim = dim
+        ctx.save_for_backward(input, other)
+        return torch.cross(input, other, ctx.dim)

-    def backward(self, grad_output):
-        input, other = self.saved_tensors
-        grad_input = torch.cross(other, grad_output, self.dim)
-        grad_other = torch.cross(grad_output, input, self.dim)
-        return grad_input, grad_other
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, other = ctx.saved_variables
+        grad_input = other.cross(grad_output, ctx.dim)
+        grad_other = grad_output.cross(input, ctx.dim)
+        return grad_input, grad_other, None


 class Inverse(Function):
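
A minimal sketch of exercising the converted Diag through the new entry point, assuming the Variable-based API of this era (Variables were still separate from tensors):

import torch
from torch.autograd import Variable
from torch.autograd._functions.linalg import Diag

x = Variable(torch.randn(3, 3), requires_grad=True)
# The integer argument is passed positionally into forward(ctx, input, diagonal_idx).
y = Diag.apply(x, 1)
# backward returns (grad for input, None for diagonal_idx), so only x.grad is populated.
y.sum().backward()
print(x.grad)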

torch/autograd/variable.py

Lines changed: 5 additions & 5 deletions
@@ -702,19 +702,19 @@ def permute(self, *permutation):
         return Permute.apply(self, permutation)

     def diag(self, diagonal_idx=0):
-        return Diag(diagonal_idx)(self)
+        return Diag.apply(self, diagonal_idx)

     def tril(self, diagonal_idx=0):
-        return Tril(diagonal_idx)(self)
+        return Tril.apply(self, diagonal_idx)

     def triu(self, diagonal_idx=0):
-        return Triu(diagonal_idx)(self)
+        return Triu.apply(self, diagonal_idx)

     def trace(self):
-        return Trace()(self)
+        return Trace.apply(self)

     def cross(self, other, dim=-1):
-        return Cross(dim)(self, other)
+        return Cross.apply(self, other, dim)

     def inverse(self):
         return Inverse.apply(self)
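
After the change, the Variable methods are thin wrappers over .apply, so the extra argument is threaded through explicitly rather than baked into a function instance. A minimal usage sketch under the same era's API:

import torch
from torch.autograd import Variable

m = Variable(torch.randn(4, 4), requires_grad=True)
t = m.tril(1)      # forwards to Tril.apply(m, 1)
s = m.trace()      # forwards to Trace.apply(m)

a = Variable(torch.randn(5, 3), requires_grad=True)
b = Variable(torch.randn(5, 3), requires_grad=True)
c = a.cross(b, 1)  # forwards to Cross.apply(a, b, 1)

# All three ops participate in one backward pass.
(t.sum() + s + c.sum()).backward()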
