Commit ac1c674

gchanan authored and soumith committed
Fix a couple of selection reduce function autograd bugs (pytorch#1702)
* Fix Median/Mode autograd functions.
* Fix kthvalue autograd function.
* Double backward for selection reduce functions.
1 parent eba3dc8 commit ac1c674

File tree

3 files changed: +58 -51 lines
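
The headline change is double backward support for the selection reduce functions (Max, Min, Mode, Median, Kthvalue). A minimal sketch of what this enables, written against the modern PyTorch API for brevity (the commit itself predates the Variable/Tensor merge):

```python
import torch

x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)

# Selection reduces return (values, indices); only values is differentiable.
values, indices = x.median(dim=1)
loss = (values ** 2).sum()

# First derivative, keeping the graph so it can be differentiated again.
gx, = torch.autograd.grad(loss, x, create_graph=True)

# Second derivative flows through the scatter in the backward pass;
# prior to this commit the backward built raw (untracked) tensors,
# so this second call would fail.
ggx, = torch.autograd.grad(gx.sum(), x)
```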

test/test_autograd.py

Lines changed: 18 additions & 10 deletions
@@ -1368,16 +1368,19 @@ class dont_convert(tuple):
     (Unfold, (), ((S, S, S), 1, 3, 1)),
     (Unfold, (), ((S, S, S), 2, 3, 2), 'lastdim'),
     (Min, (), ((S, S, S),),),
-    (Max, (1,), ((S, S, S),), 'dim', [0]),
-    (Min, (1,), ((S, S, S),), 'dim', [0]),
-    (Max, (1, False), ((S, S, S),), 'keepdim_false_dim', [0]),
-    (Min, (1, False), ((S, S, S),), 'keepdim_false_dim', [0]),
-    (Mode, (1,), ((S, S, S),), 'dim', [0]),
-    (Mode, (1, False,), ((S, S, S),), 'keepdim_false_dim', [0]),
-    (Kthvalue, (2, 0), ((S, S, S),),),
-    (Kthvalue, (2, 0, False), ((S, S, S),), "keepdim_false"),
-    (Median, (0,), ((S, S, S),),),
-    (Median, (0, False, ), ((S, S, S),), "keepdim_false"),
+    (Max, (), ((S, S, S), 1), 'dim', [0]),
+    (Min, (), ((S, S, S), 1), 'dim', [0]),
+    (Max, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
+    (Min, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
+    (Mode, (), ((S, S, S),),),
+    (Mode, (), ((S, S, S), 1), 'dim', [0]),
+    (Mode, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
+    (Kthvalue, (), ((S, S, S), 2),),
+    (Kthvalue, (), ((S, S, S), 2, 0), 'dim0'),
+    (Kthvalue, (), ((S, S, S), 2, 0, False), "keepdim_false"),
+    (Median, (), ((S, S, S),),),
+    (Median, (), ((S, S, S), 0), 'dim0'),
+    (Median, (), ((S, S, S), 0, False), "keepdim_false"),
     (Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
     (Norm, (), ((S, S, S),), '2'),
     (Norm, (3,), ((S, S, S),), '3'),
@@ -1492,8 +1495,13 @@ class dont_convert(tuple):
     ('mean', (S, S, S), ()),
     ('mean', (S, S, S), (1,), 'dim', [0]),
     ('mean', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
+    ('kthvalue', (S, S, S), (2,)),
+    ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]),
+    ('kthvalue', (S, S, S), (2, 1, False,), 'keepdim_false_dim', [1]),
+    ('median', (S, S, S), ()),
     ('median', (S, S, S), (1,), 'dim', [0]),
     ('median', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
+    ('mode', (S, S, S), ()),
     ('mode', (S, S, S), (1,), 'dim', [0]),
     ('mode', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
     ('sum', (S, S, S), ()),
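
Note the shape change in the Function entries above: the constructor-args slot (second element) is now empty, and dim/keepdim moved in next to the input tensor. That mirrors the move to the new-style `.apply` interface in the other two files; roughly:

```python
# Old style: configuration passed to the Function constructor.
output, indices = Max(1, False)(x)        # dim=1, keepdim=False

# New style: everything is passed to .apply, after the input.
output, indices = Max.apply(x, 1, False)  # input, dim, keepdim
```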

torch/autograd/_functions/reduce.py

Lines changed: 32 additions & 33 deletions
@@ -130,51 +130,50 @@ class _SelectionFunction(Function):
     # additional_args is prepended before dim when calling the tensor
     # function. It's a no-op for subclasses other than kthvalue.
     # kthvalue not only requires us to pass a dim, but also precede it with k.
-    additional_args = tuple()
-
-    def __init__(self, dim=None, keepdim=True):
-        super(_SelectionFunction, self).__init__()
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, input):
-        fn = getattr(input, type(self).__name__.lower())
-        self.input_size = input.size()
-        if self.dim is None and self.has_all_reduce:
-            value = fn(*self.additional_args)
-            self.indices = tuple(input.eq(value).nonzero()[0])
+
+    @classmethod
+    def forward(cls, ctx, input, dim=None, keepdim=True, additional_args=tuple()):
+        fn = getattr(input, cls.__name__.lower())
+        ctx.dim = dim
+        ctx.keepdim = keepdim
+        ctx.additional_args = additional_args
+        ctx.input_size = input.size()
+        if ctx.dim is None and cls.has_all_reduce:
+            value = fn(*additional_args)
+            ctx.indices_tuple = tuple(input.eq(value).nonzero()[0])
             return input.new((value,))
         else:
-            if self.dim is None:
+            if ctx.dim is None:
                 dim = input.dim() - 1
             else:
-                dim = self.dim
-            args = (dim, self.keepdim)
-            if self.additional_args:
-                args = self.additional_args + args
+                dim = ctx.dim
+            args = (dim, keepdim)
+            if additional_args:
+                args = additional_args + args
             output, indices = fn(*args)
-            self.save_for_backward(indices)
-            self.mark_non_differentiable(indices)
+            ctx.save_for_backward(indices)
+            ctx.mark_non_differentiable(indices)
             return output, indices

-    def backward(self, grad_output, grad_indices=None):
-        grad_input = grad_output.new(*self.input_size).zero_()
-        if self.dim is None and self.has_all_reduce:
-            grad_input[self.indices] = grad_output[0]
+    @classmethod
+    def backward(cls, ctx, grad_output, grad_indices=None):
+        grad_input = Variable(grad_output.data.new(*ctx.input_size).zero_())
+        if ctx.dim is None and cls.has_all_reduce:
+            grad_input[ctx.indices_tuple] = grad_output.data[0]
         else:
-            if self.dim is None:
-                dim = input.dim() - 1
+            if ctx.dim is None:
+                dim = len(ctx.input_size) - 1
             else:
-                dim = self.dim
+                dim = ctx.dim

-            indices, = self.saved_tensors
-            if self.keepdim is False:
+            indices, = ctx.saved_variables
+            if ctx.keepdim is False:
                 grad_output = grad_output.unsqueeze(dim)
                 grad_indices = grad_indices.unsqueeze(dim)
                 indices = indices.unsqueeze(dim)

             grad_input.scatter_(dim, indices, grad_output)
-        return grad_input
+        return grad_input, None, None, None


 class Max(_SelectionFunction):
@@ -196,9 +195,9 @@ class Median(_SelectionFunction):
 class Kthvalue(_SelectionFunction):
     has_all_reduce = False

-    def __init__(self, k, dim=None, keepdim=True):
-        super(Kthvalue, self).__init__(dim, keepdim)
-        self.additional_args = (k,)
+    @classmethod
+    def forward(cls, ctx, input, k, dim=None, keepdim=True):
+        return super(Kthvalue, cls).forward(ctx, input, dim, keepdim, (k,))


 class Norm(Function):
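
For readers unfamiliar with the pattern, below is a self-contained sketch of the ctx-based Function style this hunk migrates to. `MyMax` is a hypothetical, simplified analogue of the `Max` class above, spelled with today's `@staticmethod` convention rather than the `@classmethod` one in the diff. The key points: state lives on `ctx`, the indices output is marked non-differentiable, and `backward` returns one gradient slot per `forward` input.

```python
import torch

class MyMax(torch.autograd.Function):
    """Hypothetical, simplified analogue of the Max function in this diff."""

    @staticmethod
    def forward(ctx, input, dim, keepdim=True):
        output, indices = input.max(dim, keepdim=keepdim)
        ctx.dim, ctx.keepdim = dim, keepdim
        ctx.input_size = input.size()
        ctx.save_for_backward(indices)
        ctx.mark_non_differentiable(indices)  # indices never get a gradient
        return output, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices=None):
        indices, = ctx.saved_tensors
        if not ctx.keepdim:
            grad_output = grad_output.unsqueeze(ctx.dim)
            indices = indices.unsqueeze(ctx.dim)
        # Route the incoming gradient back to the selected positions.
        # Because this is built from autograd-tracked ops, backward is
        # itself differentiable, which is what enables double backward.
        grad_input = grad_output.new_zeros(ctx.input_size)
        grad_input.scatter_(ctx.dim, indices, grad_output)
        # One entry per forward() argument: input, dim, keepdim.
        return grad_input, None, None
```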

torch/autograd/variable.py

Lines changed: 8 additions & 8 deletions
@@ -451,21 +451,21 @@ def mean(self, dim=None, keepdim=True):
     def max(self, dim=None, keepdim=True):
         if isinstance(dim, Variable):
             return Cmax.apply(self, dim)
-        return Max(dim, keepdim)(self)
+        return Max.apply(self, dim, keepdim)

     def min(self, dim=None, keepdim=True):
         if isinstance(dim, Variable):
             return Cmin.apply(self, dim)
-        return Min(dim, keepdim)(self)
+        return Min.apply(self, dim, keepdim)

-    def mode(self, dim, keepdim=True):
-        return Mode(dim, keepdim)(self)
+    def mode(self, dim=None, keepdim=True):
+        return Mode.apply(self, dim, keepdim)

-    def median(self, dim, keepdim=True):
-        return Median(dim, keepdim)(self)
+    def median(self, dim=None, keepdim=True):
+        return Median.apply(self, dim, keepdim)

-    def kthvalue(self, dim, keepdim=True):
-        return Kthvalue(dim, keepdim)(self)
+    def kthvalue(self, k, dim=None, keepdim=True):
+        return Kthvalue.apply(self, k, dim, keepdim)

     def sort(self, dim=None, descending=False):
         return Sort.apply(self, dim, descending, True)
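
With the Variable methods rewired, `kthvalue` now takes `k` as its first argument (previously the method had no `k` parameter at all, so `dim` was silently passed as `k`), and `dim` becomes optional for mode/median/kthvalue. A quick numerical check of the fix and of the new double backward support, again sketched against the modern API:

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

# k comes first now; dim and keepdim are optional.
values, indices = x.kthvalue(2, dim=1)

# Verify first and second derivatives of the values output numerically.
assert gradcheck(lambda t: t.kthvalue(2, dim=1)[0], (x,))
assert gradgradcheck(lambda t: t.kthvalue(2, dim=1)[0], (x,))
```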
