
Commit 09abaa2

make keepdim backcompat warnings emit in autograd as well (pytorch#2157)
1 parent 575a4a9 commit 09abaa2
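
Before the per-file diffs, a minimal sketch of the behavior this change targets, stitched together from the new test below (illustrative only; the warning text itself comes from the tensor-level backcompat machinery, not from this commit): with the keepdim backcompat flag enabled, calling a reduction on a Variable without an explicit keepdim should emit the same UserWarning the raw tensor call does, and backward should still run.

import warnings
import torch
from torch.autograd import Variable

torch.utils.backcompat.keepdim_warning.enabled = True
x = Variable(torch.randn(3, 4), requires_grad=True)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    y = torch.sum(x, 1)  # keepdim omitted: expect a UserWarning mentioning "keepdim"
assert any("keepdim" in str(m.message) for m in w)
y.backward(torch.ones(y.size()))  # gradients still flow through the warned call
torch.utils.backcompat.keepdim_warning.enabled = False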

File tree (4 files changed: +93 -31 lines)
  setup.py
  test/test_autograd.py
  torch/autograd/_functions/reduce.py
  torch/autograd/variable.py


setup.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def run(self):
         CXXNAME = os.getenv('CXX', 'g++')
         path = subprocess.check_output([CXXNAME, '-print-file-name=libstdc++.a'])
         path = path[:-1]
-        if type(path) != str:  # python 3
+        if type(path) != str:  # python 3
             path = path.decode(sys.stdout.encoding)
         extra_link_args += [path]

test/test_autograd.py

Lines changed: 46 additions & 0 deletions
@@ -1366,6 +1366,52 @@ def backward(self, grad_output):
         c.backward(torch.ones(c.size()))
         self.assertEqual(x.grad.data, torch.ones(x.size()))
 
+    def test_keepdim_warning(self):
+        torch.utils.backcompat.keepdim_warning.enabled = True
+        x = Variable(torch.randn(3, 4), requires_grad=True)
+
+        def run_backward(y):
+            y_ = y
+            if type(y) is tuple:
+                y_ = y[0]
+            # check that backward runs smooth
+            y_.backward(y_.data.new(y_.size()).normal_())
+
+        def keepdim_check(f):
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                y = f(x, 1)
+                self.assertTrue(len(w) == 1)
+                self.assertTrue(issubclass(w[-1].category, UserWarning))
+                self.assertTrue("keepdim" in str(w[-1].message))
+                run_backward(y)
+                self.assertEqual(x.size(), x.grad.size())
+
+                # check against explicit keepdim
+                y2 = f(x, 1, keepdim=False)
+                self.assertEqual(y, y2)
+                run_backward(y2)
+
+                y3 = f(x, 1, keepdim=True)
+                if type(y3) == tuple:
+                    y3 = (y3[0].squeeze(1), y3[1].squeeze(1))
+                else:
+                    y3 = y3.squeeze(1)
+                self.assertEqual(y, y3)
+                run_backward(y3)
+
+        keepdim_check(torch.sum)
+        keepdim_check(torch.prod)
+        keepdim_check(torch.mean)
+        keepdim_check(torch.max)
+        keepdim_check(torch.min)
+        keepdim_check(torch.mode)
+        keepdim_check(torch.median)
+        keepdim_check(torch.kthvalue)
+        keepdim_check(torch.var)
+        keepdim_check(torch.std)
+        torch.utils.backcompat.keepdim_warning.enabled = False
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):

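The new test asserts, for each reduction f, that the warned default produces the same result as an explicit keepdim=False, and that keepdim=True differs only by the kept dimension. The same equivalences in a compressed, stand-alone sketch (not part of the test; torch.equal and the stated shapes are the only assumptions beyond what the test exercises):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3, 4), requires_grad=True)
y = torch.sum(x, 1)                    # default: warns while the backcompat flag is on, dim 1 removed
y2 = torch.sum(x, 1, keepdim=False)    # explicit False: same values and shape as y
y3 = torch.sum(x, 1, keepdim=True)     # explicit True: dim 1 kept with size 1
assert torch.equal(y.data, y2.data)
assert torch.equal(y.data, y3.squeeze(1).data)
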
torch/autograd/_functions/reduce.py

Lines changed: 32 additions & 17 deletions
@@ -8,14 +8,17 @@
 class Sum(Function):
 
     @staticmethod
-    def forward(ctx, input, dim=None, keepdim=False):
+    def forward(ctx, input, dim=None, keepdim=None):
         ctx.dim = dim
-        ctx.keepdim = keepdim
+        ctx.keepdim = False if keepdim is None else keepdim
         ctx.input_size = input.size()
         if dim is None:
             return input.new((input.sum(),))
         else:
-            return input.sum(dim, keepdim)
+            if keepdim is not None:
+                return input.sum(dim, keepdim=keepdim)
+            else:
+                return input.sum(dim)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -33,16 +36,19 @@ def backward(ctx, grad_output):
 class Prod(Function):
 
     @staticmethod
-    def forward(ctx, input, dim=None, keepdim=False):
+    def forward(ctx, input, dim=None, keepdim=None):
         ctx.dim = dim
-        ctx.keepdim = keepdim
+        ctx.keepdim = False if keepdim is None else keepdim
         ctx.input_size = input.size()
         if dim is None:
             ctx.result = input.prod()
             ctx.save_for_backward(input)
             return input.new((ctx.result,))
         else:
-            output = input.prod(dim, keepdim)
+            if keepdim is not None:
+                output = input.prod(dim, keepdim=keepdim)
+            else:
+                output = input.prod(dim)
             ctx.save_for_backward(input, output)
             return output
 
@@ -105,14 +111,17 @@ def reverse_dim(var, dim):
 class Mean(Function):
 
     @staticmethod
-    def forward(ctx, input, dim=None, keepdim=False):
+    def forward(ctx, input, dim=None, keepdim=None):
         ctx.dim = dim
-        ctx.keepdim = keepdim
+        ctx.keepdim = False if keepdim is None else keepdim
         ctx.input_size = input.size()
         if dim is None:
             return input.new((input.mean(),))
         else:
-            return input.mean(dim, keepdim)
+            if keepdim is not None:
+                return input.mean(dim, keepdim=keepdim)
+            else:
+                return input.mean(dim)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -136,10 +145,10 @@ class _SelectionFunction(Function):
     # kthvalue not only requires us to pass a dim, but also preceed it with k.
 
     @classmethod
-    def forward(cls, ctx, input, dim=None, keepdim=False, additional_args=tuple()):
+    def forward(cls, ctx, input, dim=None, keepdim=None, additional_args=tuple()):
         fn = getattr(input, cls.__name__.lower())
         ctx.dim = dim
-        ctx.keepdim = keepdim
+        ctx.keepdim = False if keepdim is None else keepdim
         ctx.additional_args = additional_args
         ctx.input_size = input.size()
         if ctx.dim is None and cls.has_all_reduce:
@@ -151,10 +160,13 @@ def forward(cls, ctx, input, dim=None, keepdim=False, additional_args=tuple()):
                 dim = input.dim() - 1
             else:
                 dim = ctx.dim
-            args = (dim, keepdim)
+            args = (dim,)
             if additional_args:
                 args = additional_args + args
-            output, indices = fn(*args)
+            if keepdim is not None:
+                output, indices = fn(*args, keepdim=keepdim)
+            else:
+                output, indices = fn(*args)
             ctx.save_for_backward(indices)
             ctx.mark_non_differentiable(indices)
             return output, indices
@@ -200,24 +212,27 @@ class Kthvalue(_SelectionFunction):
     has_all_reduce = False
 
     @classmethod
-    def forward(cls, ctx, input, k, dim=None, keepdim=False):
+    def forward(cls, ctx, input, k, dim=None, keepdim=None):
         return super(Kthvalue, cls).forward(ctx, input, dim, keepdim, (k,))
 
 
 class Norm(Function):
 
     @staticmethod
-    def forward(ctx, input, p=2, dim=None, keepdim=False):
+    def forward(ctx, input, p=2, dim=None, keepdim=None):
         ctx.p = p
         ctx.dim = dim
-        ctx.keepdim = keepdim
+        ctx.keepdim = False if keepdim is None else keepdim
 
         if dim is None:
             ctx.norm = input.norm(p)
             ctx.save_for_backward(input)
             return input.new((ctx.norm,))
         else:
-            output = input.norm(p, dim, keepdim)
+            if keepdim is not None:
+                output = input.norm(p, dim, keepdim=keepdim)
+            else:
+                output = input.norm(p, dim)
             ctx.save_for_backward(input, output)
             return output

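The recurring edit in reduce.py above follows one pattern: keepdim now defaults to None so each Function can distinguish "caller omitted keepdim" from an explicit False, it records an effective value (False when omitted) for use in backward, and it forwards the keyword to the underlying tensor op only when the caller supplied it, so the tensor-level backcompat warning fires exactly when the argument was left out. The pattern in isolation (illustrative names, not part of the real API):

def call_reduction(tensor_op, dim, keepdim=None):
    # effective value kept for backward; None falls back to the new False default
    effective_keepdim = False if keepdim is None else keepdim
    if keepdim is not None:
        out = tensor_op(dim, keepdim=keepdim)  # explicit: no backcompat warning expected
    else:
        out = tensor_op(dim)  # omitted: the tensor op itself may warn
    return out, effective_keepdim
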
torch/autograd/variable.py

Lines changed: 14 additions & 13 deletions
@@ -449,32 +449,32 @@ def lerp(self, tensor, weight):
     def rsqrt(self):
         return Rsqrt.apply(self)
 
-    def sum(self, dim=None, keepdim=False):
+    def sum(self, dim=None, keepdim=None):
         return Sum.apply(self, dim, keepdim)
 
-    def prod(self, dim=None, keepdim=False):
+    def prod(self, dim=None, keepdim=None):
         return Prod.apply(self, dim, keepdim)
 
-    def mean(self, dim=None, keepdim=False):
+    def mean(self, dim=None, keepdim=None):
         return Mean.apply(self, dim, keepdim)
 
-    def max(self, dim=None, keepdim=False):
+    def max(self, dim=None, keepdim=None):
         if isinstance(dim, Variable):
             return Cmax.apply(self, dim)
         return Max.apply(self, dim, keepdim)
 
-    def min(self, dim=None, keepdim=False):
+    def min(self, dim=None, keepdim=None):
         if isinstance(dim, Variable):
             return Cmin.apply(self, dim)
         return Min.apply(self, dim, keepdim)
 
-    def mode(self, dim=None, keepdim=False):
+    def mode(self, dim=None, keepdim=None):
         return Mode.apply(self, dim, keepdim)
 
-    def median(self, dim=None, keepdim=False):
+    def median(self, dim=None, keepdim=None):
         return Median.apply(self, dim, keepdim)
 
-    def kthvalue(self, k, dim=None, keepdim=False):
+    def kthvalue(self, k, dim=None, keepdim=None):
         return Kthvalue.apply(self, k, dim, keepdim)
 
     def sort(self, dim=None, descending=False):
@@ -508,20 +508,21 @@ def cumprod(self, dim):
     def unfold(self, dim, size, step):
         return Unfold.apply(self, dim, size, step)
 
-    def var(self, dim=None, keepdim=False, unbiased=True):
+    def var(self, dim=None, keepdim=None, unbiased=True):
+        keepdim_ = False if keepdim is None else keepdim
         mean = self.mean(dim, keepdim)
         if dim is None:
             mean = mean.view(*(1 for s in self.size()))
         # we could just set keepdim to True, but this preserves some fidelity
-        elif keepdim is False and self.dim() != 1:
+        elif keepdim_ is False and self.dim() != 1:
             mean = mean.unsqueeze(dim)
         mean_expanded = mean.expand_as(self)
         zero_centered = self.sub(mean_expanded)
-        var = zero_centered.mul(zero_centered).sum(dim, keepdim)
+        var = zero_centered.mul(zero_centered).sum(dim, keepdim=keepdim_)
         numel = self.numel() if dim is None else self.size(dim)
         return var.div(numel - int(unbiased))
 
-    def std(self, dim=None, keepdim=False, unbiased=True):
+    def std(self, dim=None, keepdim=None, unbiased=True):
         return self.var(dim, keepdim, unbiased).sqrt()
 
     def renorm(self, p, dim, maxnorm):
@@ -626,7 +627,7 @@ def addcmul_(self, *args):
     def addcdiv_(self, *args):
         return self._addcop(Addcdiv, args, True)
 
-    def norm(self, p=2, dim=None, keepdim=False):
+    def norm(self, p=2, dim=None, keepdim=None):
         return Norm.apply(self, p, dim, keepdim)
 
     def dist(self, tensor, p=2):

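The variable.py changes are the user-facing part: every reduction method on Variable now defaults keepdim to None and passes it through unchanged, so omitting it behaves (and warns) like the corresponding tensor call, while var and std normalize None to False internally via keepdim_. A minimal usage sketch (shapes assume a 3x4 input, as in the test above):

import torch
from torch.autograd import Variable

v = Variable(torch.randn(3, 4), requires_grad=True)
a = v.std(1)                 # keepdim omitted: forwarded as None, may warn, result has size (3,)
b = v.std(1, keepdim=True)   # explicit: result keeps dim 1, size (3, 1)
c = v.var(1, keepdim=False)  # explicit False: handled via the internal keepdim_ normalization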