
Commit ae2b2cb

gchanan authored and soumith committed
Make keepdim work with autograd.
1 parent f4cf1d6 commit ae2b2cb
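
The changes below thread a keepdim flag through the autograd reduction Functions. For context, a minimal sketch of what keepdim means for result shapes (written against a plain torch install, not taken from this commit): with keepdim=False the reduced dimension is dropped, and with keepdim=True a size-1 dimension is kept so the result can be expanded back over the input.

import torch

x = torch.randn(10, 20)

s = x.sum(1)                              # keepdim=False: shape (10,), dim 1 is dropped
s_keep = x.sum(1, True)                   # keepdim=True: shape (10, 1)
normalized = x / s_keep.expand(10, 20)    # valid because the singleton dim survived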

File tree

9 files changed, +106 -78 lines changed


test/common_nn.py

Lines changed: 2 additions & 2 deletions
@@ -88,7 +88,7 @@
     dict(
         module_name='Softmax',
         input_size=(10, 20),
-        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1).expand(10, 20))
+        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))
     ),
     dict(
         module_name='Softmax2d',
@@ -98,7 +98,7 @@
     dict(
         module_name='LogSoftmax',
         input_size=(10, 20),
-        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand(10, 20)).log_()
+        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()
     ),
     dict(
         module_name='LogSoftmax',

test/test_autograd.py

Lines changed: 1 addition & 1 deletion
@@ -1041,7 +1041,7 @@ def test_stochastic(self):
         x = Variable(torch.rand(2, 10), requires_grad=True)
         stddevs = Variable(torch.rand(2, 10) * 5, requires_grad=True)
         y = (x * 2).clamp(0, 1)
-        y = y / y.sum(1).expand_as(y)
+        y = y / y.sum(1, True).expand_as(y)
         samples_multi = y.multinomial(5)
         samples_multi_flat = y[0].multinomial(5)
         samples_bernoulli = y.bernoulli()

test/test_torch.py

Lines changed: 28 additions & 28 deletions
@@ -159,17 +159,17 @@ def _testSelection(self, torchfn, mathfn):
         # with indices
         m1 = torch.randn(100, 100)
         res1val, res1ind = torchfn(m1, 1)
-        res2val = m1[:, 0:1].clone()
+        res2val = m1[:, 0:1].clone().squeeze()
         res2ind = res1ind.clone().fill_(0)
         for i, j in iter_indices(m1):
-            if mathfn(res2val[i, 0], m1[i, j]) != res2val[i, 0]:
-                res2val[i, 0] = m1[i, j]
-                res2ind[i, 0] = j
+            if mathfn(res2val[i], m1[i, j]) != res2val[i]:
+                res2val[i] = m1[i, j]
+                res2ind[i] = j

         maxerr = 0
         for i in range(res1val.size(0)):
-            maxerr = max(maxerr, abs(res1val[i][0] - res2val[i][0]))
-            self.assertEqual(res1ind[i][0], res2ind[i][0])
+            maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
+            self.assertEqual(res1ind[i], res2ind[i])
         self.assertLessEqual(abs(maxerr), 1e-5)

         # NaNs
@@ -514,22 +514,22 @@ def test_addbmm(self):
         res2 = torch.Tensor().resize_as_(res[0]).zero_()

         res2.addbmm_(b1, b2)
-        self.assertEqual(res2, res.sum(0)[0])
+        self.assertEqual(res2, res.sum(0))

         res2.addbmm_(1, b1, b2)
-        self.assertEqual(res2, res.sum(0)[0] * 2)
+        self.assertEqual(res2, res.sum(0) * 2)

         res2.addbmm_(1., .5, b1, b2)
-        self.assertEqual(res2, res.sum(0)[0] * 2.5)
+        self.assertEqual(res2, res.sum(0) * 2.5)

         res3 = torch.addbmm(1, res2, 0, b1, b2)
         self.assertEqual(res3, res2)

         res4 = torch.addbmm(1, res2, .5, b1, b2)
-        self.assertEqual(res4, res.sum(0)[0] * 3)
+        self.assertEqual(res4, res.sum(0) * 3)

         res5 = torch.addbmm(0, res2, 1, b1, b2)
-        self.assertEqual(res5, res.sum(0)[0])
+        self.assertEqual(res5, res.sum(0))

         res6 = torch.addbmm(.1, res2, .5, b1, b2)
         self.assertEqual(res6, res2 * .1 + res.sum(0) * .5)
@@ -744,7 +744,7 @@ def renorm(matrix, value, dim, max_norm):
             m1 = matrix.transpose(dim, 0).contiguous()
             # collapse non-dim dimensions.
             m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
-            norms = m2.norm(value, 1)
+            norms = m2.norm(value, 1, True)
             # clip
             new_norms = norms.clone()
             new_norms[torch.gt(norms, max_norm)] = max_norm
@@ -1070,23 +1070,23 @@ def test_kthvalue(self):
         res1val, res1ind = torch.kthvalue(x, k)
         res2val, res2ind = torch.sort(x)

-        self.assertEqual(res1val[:, :, 0], res2val[:, :, k - 1], 0)
-        self.assertEqual(res1ind[:, :, 0], res2ind[:, :, k - 1], 0)
+        self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
+        self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
         # test use of result tensors
         k = random.randint(1, SIZE)
         res1val = torch.Tensor()
         res1ind = torch.LongTensor()
         torch.kthvalue(x, k, out=(res1val, res1ind))
         res2val, res2ind = torch.sort(x)
-        self.assertEqual(res1val[:, :, 0], res2val[:, :, k - 1], 0)
-        self.assertEqual(res1ind[:, :, 0], res2ind[:, :, k - 1], 0)
+        self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
+        self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)

         # test non-default dim
         k = random.randint(1, SIZE)
         res1val, res1ind = torch.kthvalue(x, k, 0)
         res2val, res2ind = torch.sort(x, 0)
-        self.assertEqual(res1val[0], res2val[k - 1], 0)
-        self.assertEqual(res1ind[0], res2ind[k - 1], 0)
+        self.assertEqual(res1val, res2val[k - 1], 0)
+        self.assertEqual(res1ind, res2ind[k - 1], 0)

         # non-contiguous
         y = x.narrow(1, 0, 1)
@@ -1110,12 +1110,12 @@ def test_median(self):
         x = torch.rand(size, size)
         x0 = x.clone()

-        res1val, res1ind = torch.median(x)
+        res1val, res1ind = torch.median(x, False)
         res2val, res2ind = torch.sort(x)
         ind = int(math.floor((size + 1) / 2) - 1)

-        self.assertEqual(res2val.select(1, ind), res1val.select(1, 0), 0)
-        self.assertEqual(res2val.select(1, ind), res1val.select(1, 0), 0)
+        self.assertEqual(res2val.select(1, ind), res1val, 0)
+        self.assertEqual(res2val.select(1, ind), res1val, 0)

         # Test use of result tensor
         res2val = torch.Tensor()
@@ -1127,8 +1127,8 @@ def test_median(self):
         # Test non-default dim
         res1val, res1ind = torch.median(x, 0)
         res2val, res2ind = torch.sort(x, 0)
-        self.assertEqual(res1val[0], res2val[ind], 0)
-        self.assertEqual(res1ind[0], res2ind[ind], 0)
+        self.assertEqual(res1val, res2val[ind], 0)
+        self.assertEqual(res1ind, res2ind[ind], 0)

         # input unchanged
         self.assertEqual(x, x0, 0)
@@ -1140,9 +1140,9 @@ def test_mode(self):
         x0 = x.clone()

         # Pre-calculated results.
-        res1val = torch.Tensor(SIZE, 1).fill_(1)
+        res1val = torch.Tensor(SIZE).fill_(1)
         # The indices are the position of the last appearance of the mode element.
-        res1ind = torch.LongTensor(SIZE, 1).fill_(1)
+        res1ind = torch.LongTensor(SIZE).fill_(1)
         res1ind[0] = SIZE - 1
         res1ind[1] = SIZE - 1

@@ -1160,8 +1160,8 @@ def test_mode(self):

         # Test non-default dim
         res2val, res2ind = torch.mode(x, 0)
-        self.assertEqual(res1val.view(1, SIZE), res2val, 0)
-        self.assertEqual(res1ind.view(1, SIZE), res2ind, 0)
+        self.assertEqual(res1val, res2val, 0)
+        self.assertEqual(res1ind, res2ind, 0)

         # input unchanged
         self.assertEqual(x, x0, 0)
@@ -2217,7 +2217,7 @@ def _test_gather(self, cast, test_bounds=True):
         self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))

         src = cast(torch.randn(3, 4, 5))
-        expected, idx = src.max(2)
+        expected, idx = src.max(2, True)
         expected = cast(expected)
         idx = cast(idx)
         actual = torch.gather(src, 2, idx)
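
The _test_gather change switches to src.max(2, True) because torch.gather expects the index tensor to have the same number of dimensions as the source, so keeping the reduced dimension makes the argmax indices usable directly. A small shape sketch (an assumed example, not taken from the test file):

import torch

src = torch.randn(3, 4, 5)
values, idx = src.max(2, True)        # idx shape: (3, 4, 1) -- same ndim as src
gathered = torch.gather(src, 2, idx)  # shape: (3, 4, 1); equals values
# Without keepdim, idx would have shape (3, 4) and would need idx.unsqueeze(2)
# before it could be used with gather along dim 2.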

torch/autograd/_functions/reduce.py

Lines changed: 42 additions & 17 deletions
@@ -7,36 +7,41 @@
 class Sum(Function):

     @staticmethod
-    def forward(ctx, input, dim=None):
+    def forward(ctx, input, dim=None, keepdim=False):
         ctx.dim = dim
+        ctx.keepdim = keepdim
         ctx.input_size = input.size()
         if dim is None:
             return input.new((input.sum(),))
         else:
-            return input.sum(dim)
+            return input.sum(dim, keepdim)

     @staticmethod
     def backward(ctx, grad_output):
         if ctx.dim is None:
-            return grad_output.expand(ctx.input_size), None
+            return grad_output.expand(ctx.input_size), None, None
         else:
+            if ctx.keepdim is False:
+                grad_output = grad_output.unsqueeze(ctx.dim)
+
             repeats = [1 for _ in ctx.input_size]
             repeats[ctx.dim] = ctx.input_size[ctx.dim]
-            return grad_output.repeat(*repeats), None
+            return grad_output.repeat(*repeats), None, None


 class Prod(Function):

     @staticmethod
-    def forward(ctx, input, dim=None):
+    def forward(ctx, input, dim=None, keepdim=False):
         ctx.dim = dim
+        ctx.keepdim = keepdim
         ctx.input_size = input.size()
         if dim is None:
             ctx.result = input.prod()
             ctx.save_for_backward(input)
             return input.new((ctx.result,))
         else:
-            output = input.prod(dim)
+            output = input.prod(dim, keepdim)
             ctx.save_for_backward(input, output)
             return output

@@ -59,8 +64,11 @@ def backward(ctx, grad_output):
         else:
             input, output = ctx.saved_variables
             dim = ctx.dim if ctx.dim >= 0 else ctx.dim + input.dim()
+            if ctx.keepdim is False:
+                grad_output = grad_output.unsqueeze(dim)
+
             zero_mask = input == 0
-            slice_zero_count = zero_mask.sum(dim)
+            slice_zero_count = zero_mask.sum(dim, True)
             total_zeros = slice_zero_count.sum()
             grad_input = grad_output.mul(output).expand_as(input).div(input)
             if total_zeros == 0:
@@ -93,24 +101,28 @@ def backward(ctx, grad_output):
 class Mean(Function):

     @staticmethod
-    def forward(ctx, input, dim=None):
+    def forward(ctx, input, dim=None, keepdim=False):
         ctx.dim = dim
+        ctx.keepdim = keepdim
         ctx.input_size = input.size()
         if dim is None:
             return input.new((input.mean(),))
         else:
-            return input.mean(dim)
+            return input.mean(dim, keepdim)

     @staticmethod
     def backward(ctx, grad_output):
         if ctx.dim is None:
             grad_input_val = grad_output / reduce(lambda x, y: x * y, ctx.input_size, 1)
-            return grad_input_val.expand(ctx.input_size), None
+            return grad_input_val.expand(ctx.input_size), None, None
         else:
+            if ctx.keepdim is False:
+                grad_output = grad_output.unsqueeze(ctx.dim)
+
             repeats = [1 for _ in ctx.input_size]
             dim_size = ctx.input_size[ctx.dim]
             repeats[ctx.dim] = dim_size
-            return grad_output.repeat(*repeats).div_(dim_size), None
+            return grad_output.repeat(*repeats).div_(dim_size), None, None


 class _SelectionFunction(Function):
@@ -120,9 +132,10 @@ class _SelectionFunction(Function):
     # kthvalue not only requires us to pass a dim, but also preceed it with k.
     additional_args = tuple()

-    def __init__(self, dim=None):
+    def __init__(self, dim=None, keepdim=False):
         super(_SelectionFunction, self).__init__()
         self.dim = dim
+        self.keepdim = keepdim

     def forward(self, input):
         fn = getattr(input, type(self).__name__.lower())
@@ -136,7 +149,7 @@ def forward(self, input):
             dim = input.dim() - 1
         else:
             dim = self.dim
-        args = (dim,)
+        args = (dim, self.keepdim)
         if self.additional_args:
             args = self.additional_args + args
         output, indices = fn(*args)
@@ -153,7 +166,13 @@ def backward(self, grad_output, grad_indices=None):
             dim = input.dim() - 1
         else:
             dim = self.dim
+
         indices, = self.saved_tensors
+        if self.keepdim is False:
+            grad_output = grad_output.unsqueeze(dim)
+            grad_indices = grad_indices.unsqueeze(dim)
+            indices = indices.unsqueeze(dim)
+
         grad_input.scatter_(dim, indices, grad_output)
         return grad_input

@@ -177,25 +196,26 @@ class Median(_SelectionFunction):
 class Kthvalue(_SelectionFunction):
     has_all_reduce = False

-    def __init__(self, k, dim=None):
-        super(Kthvalue, self).__init__(dim)
+    def __init__(self, k, dim=None, keepdim=False):
+        super(Kthvalue, self).__init__(dim, keepdim)
         self.additional_args = (k,)


 class Norm(Function):

-    def __init__(self, norm_type=2, dim=None):
+    def __init__(self, norm_type=2, dim=None, keepdim=False):
         super(Norm, self).__init__()
         self.norm_type = norm_type
         self.dim = dim
+        self.keepdim = keepdim

     def forward(self, input):
         if self.dim is None:
             self.norm = input.norm(self.norm_type)
             self.save_for_backward(input)
             return input.new((self.norm,))
         else:
-            output = input.norm(self.norm_type, self.dim)
+            output = input.norm(self.norm_type, self.dim, self.keepdim)
             self.save_for_backward(input, output)
             return output

@@ -210,6 +230,11 @@ def backward(self, grad_output):
             return input.mul(pow).mul(scale)
         else:
             input, output = self.saved_tensors
+
+            if self.keepdim is False:
+                grad_output = grad_output.unsqueeze(self.dim)
+                output = output.unsqueeze(self.dim)
+
             big_grad_output = grad_output.expand_as(input)
             if self.norm_type == 2:
                 big_output = output.expand_as(input)
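
The pattern repeated across these backward methods: if the forward reduction ran with keepdim=False, the incoming gradient is missing the reduced dimension, so it is unsqueezed back at that dim before being repeated or expanded to the input's shape. A stripped-down sketch of that bookkeeping for Sum (an illustrative helper, not the module's code):

import torch

def sum_backward(grad_output, input_size, dim, keepdim):
    # Re-insert the reduced dimension if the forward pass dropped it,
    # so the gradient broadcasts back over the original input shape.
    if not keepdim:
        grad_output = grad_output.unsqueeze(dim)
    # Each input element contributes to the sum with weight 1, so the
    # gradient is just the output gradient replicated along dim.
    return grad_output.expand(input_size)

x = torch.randn(4, 5)
g = torch.ones(4)                          # grad of y = x.sum(1) with keepdim=False
gx = sum_backward(g, x.size(), 1, False)   # shape: (4, 5)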

torch/autograd/_functions/stochastic.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def backward(self, reward):
             probs = probs.unsqueeze(0)
             samples = samples.unsqueeze(0)
         # normalize probs (multinomial accepts weights)
-        probs /= probs.sum(1).expand_as(probs)
+        probs /= probs.sum(1, True).expand_as(probs)
         grad_probs = probs.new().resize_as_(probs).zero_()
         output_probs = probs.gather(1, samples)
         output_probs.add_(1e-6).reciprocal_()

torch/autograd/_functions/tensor.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def backward(ctx, grad_output):
         for i in range(ctx.num_unsqueezed):
            grad_input = grad_input.sum(0).squeeze(0)
         for dim in ctx.expanded_dims:
-            grad_input = grad_input.sum(dim)
+            grad_input = grad_input.sum(dim, True)
         return grad_input, None
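
The Expand backward change sums over each expanded dimension with keepdim=True, so the gradient collapses back to the original singleton dimension instead of losing it. A rough illustration with a hypothetical helper (not the Function itself):

import torch

def reduce_expanded_grad(grad, expanded_dims):
    # Sum over every dim that was expanded from size 1, keeping the dim
    # so the result matches the pre-expand shape (e.g. (4, 1), not (4,)).
    for dim in expanded_dims:
        grad = grad.sum(dim, True)
    return grad

# Suppose x had shape (4, 1) and was expanded to (4, 3).
grad_out = torch.ones(4, 3)
grad_in = reduce_expanded_grad(grad_out, expanded_dims=[1])
print(grad_in.shape)   # torch.Size([4, 1]); each entry is 3.0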

0 commit comments
