Skip to content

Commit c4742fd

Browse files
committed
Explicitly pass keepdim=False for tests that require it.
If we change the default to False, reverting this commit is optional.
1 parent e124790 commit c4742fd

File tree

2 files changed

+22
-23
lines changed

2 files changed

+22
-23
lines changed

test/test_legacy_nn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,19 +233,19 @@ def _do_test(self, test_case, module, input):
233233
reference_fn=lambda i, _: torch.bmm(i[0], i[1].view(i[1].size(0), i[1].size(1), 1)).squeeze()),
234234
OldModuleTest(nn.Max,
235235
input_size=(4, 5, 3),
236-
reference_fn=lambda i, _: torch.max(i, 0)[0]),
236+
reference_fn=lambda i, _: torch.max(i, 0, False)[0]),
237237
OldModuleTest(nn.Max,
238238
(1,),
239239
input_size=(4, 5, 3),
240-
reference_fn=lambda i, _: torch.max(i, 1)[0],
240+
reference_fn=lambda i, _: torch.max(i, 1, False)[0],
241241
desc='with_dimension'),
242242
OldModuleTest(nn.Min,
243243
input_size=(4, 5, 3),
244-
reference_fn=lambda i, _: torch.min(i, 0)[0]),
244+
reference_fn=lambda i, _: torch.min(i, 0, False)[0]),
245245
OldModuleTest(nn.Min,
246246
(1,),
247247
input_size=(4, 5, 3),
248-
reference_fn=lambda i, _: torch.min(i, 1)[0],
248+
reference_fn=lambda i, _: torch.min(i, 1, False)[0],
249249
desc='with_dimension'),
250250
OldModuleTest(nn.MixtureTable,
251251
tuple(),

test/test_torch.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _testSelection(self, torchfn, mathfn):
158158

159159
# with indices
160160
m1 = torch.randn(100, 100)
161-
res1val, res1ind = torchfn(m1, 1)
161+
res1val, res1ind = torchfn(m1, 1, False)
162162
res2val = m1[:, 0:1].clone().squeeze()
163163
res2ind = res1ind.clone().fill_(0)
164164
for i, j in iter_indices(m1):
@@ -206,9 +206,9 @@ def fn(t, dim, keepdim=True):
206206
return ans if not isinstance(ans, tuple) else ans[0]
207207

208208
dim = random.randint(0, 2)
209-
self.assertEqual(fn(x, dim, True).unsqueeze(dim), fn(x, dim))
210-
self.assertEqual(x.ndimension() - 1, fn(x, dim, True).ndimension())
211-
self.assertEqual(x.ndimension(), fn(x, dim).ndimension())
209+
self.assertEqual(fn(x, dim, False).unsqueeze(dim), fn(x, dim))
210+
self.assertEqual(x.ndimension() - 1, fn(x, dim, False).ndimension())
211+
self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())
212212

213213
# check 1-d behavior
214214
x = torch.randn(1)
@@ -543,22 +543,22 @@ def test_addbmm(self):
543543
res2 = torch.Tensor().resize_as_(res[0]).zero_()
544544

545545
res2.addbmm_(b1, b2)
546-
self.assertEqual(res2, res.sum(0))
546+
self.assertEqual(res2, res.sum(0, False))
547547

548548
res2.addbmm_(1, b1, b2)
549-
self.assertEqual(res2, res.sum(0) * 2)
549+
self.assertEqual(res2, res.sum(0, False) * 2)
550550

551551
res2.addbmm_(1., .5, b1, b2)
552-
self.assertEqual(res2, res.sum(0) * 2.5)
552+
self.assertEqual(res2, res.sum(0, False) * 2.5)
553553

554554
res3 = torch.addbmm(1, res2, 0, b1, b2)
555555
self.assertEqual(res3, res2)
556556

557557
res4 = torch.addbmm(1, res2, .5, b1, b2)
558-
self.assertEqual(res4, res.sum(0) * 3)
558+
self.assertEqual(res4, res.sum(0, False) * 3)
559559

560560
res5 = torch.addbmm(0, res2, 1, b1, b2)
561-
self.assertEqual(res5, res.sum(0))
561+
self.assertEqual(res5, res.sum(0, False))
562562

563563
res6 = torch.addbmm(.1, res2, .5, b1, b2)
564564
self.assertEqual(res6, res2 * .1 + res.sum(0) * .5)
@@ -1096,7 +1096,7 @@ def test_kthvalue(self):
10961096
x0 = x.clone()
10971097

10981098
k = random.randint(1, SIZE)
1099-
res1val, res1ind = torch.kthvalue(x, k)
1099+
res1val, res1ind = torch.kthvalue(x, k, False)
11001100
res2val, res2ind = torch.sort(x)
11011101

11021102
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
@@ -1105,14 +1105,14 @@ def test_kthvalue(self):
11051105
k = random.randint(1, SIZE)
11061106
res1val = torch.Tensor()
11071107
res1ind = torch.LongTensor()
1108-
torch.kthvalue(x, k, out=(res1val, res1ind))
1108+
torch.kthvalue(x, k, False, out=(res1val, res1ind))
11091109
res2val, res2ind = torch.sort(x)
11101110
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
11111111
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
11121112

11131113
# test non-default dim
11141114
k = random.randint(1, SIZE)
1115-
res1val, res1ind = torch.kthvalue(x, k, 0)
1115+
res1val, res1ind = torch.kthvalue(x, k, 0, False)
11161116
res2val, res2ind = torch.sort(x, 0)
11171117
self.assertEqual(res1val, res2val[k - 1], 0)
11181118
self.assertEqual(res1ind, res2ind[k - 1], 0)
@@ -1139,7 +1139,7 @@ def test_median(self):
11391139
x = torch.rand(size, size)
11401140
x0 = x.clone()
11411141

1142-
res1val, res1ind = torch.median(x, False)
1142+
res1val, res1ind = torch.median(x, keepdim=False)
11431143
res2val, res2ind = torch.sort(x)
11441144
ind = int(math.floor((size + 1) / 2) - 1)
11451145

@@ -1149,12 +1149,12 @@ def test_median(self):
11491149
# Test use of result tensor
11501150
res2val = torch.Tensor()
11511151
res2ind = torch.LongTensor()
1152-
torch.median(x, out=(res2val, res2ind))
1152+
torch.median(x, keepdim=False, out=(res2val, res2ind))
11531153
self.assertEqual(res2val, res1val, 0)
11541154
self.assertEqual(res2ind, res1ind, 0)
11551155

11561156
# Test non-default dim
1157-
res1val, res1ind = torch.median(x, 0)
1157+
res1val, res1ind = torch.median(x, 0, keepdim=False)
11581158
res2val, res2ind = torch.sort(x, 0)
11591159
self.assertEqual(res1val, res2val[ind], 0)
11601160
self.assertEqual(res1ind, res2ind[ind], 0)
@@ -1175,20 +1175,19 @@ def test_mode(self):
11751175
res1ind[0] = SIZE - 1
11761176
res1ind[1] = SIZE - 1
11771177

1178-
res2val, res2ind = torch.mode(x)
1179-
1178+
res2val, res2ind = torch.mode(x, keepdim=False)
11801179
self.assertEqual(res1val, res2val, 0)
11811180
self.assertEqual(res1ind, res2ind, 0)
11821181

11831182
# Test use of result tensor
11841183
res2val = torch.Tensor()
11851184
res2ind = torch.LongTensor()
1186-
torch.mode(x, out=(res2val, res2ind))
1185+
torch.mode(x, keepdim=False, out=(res2val, res2ind))
11871186
self.assertEqual(res1val, res2val, 0)
11881187
self.assertEqual(res1ind, res2ind, 0)
11891188

11901189
# Test non-default dim
1191-
res2val, res2ind = torch.mode(x, 0)
1190+
res2val, res2ind = torch.mode(x, 0, False)
11921191
self.assertEqual(res1val, res2val, 0)
11931192
self.assertEqual(res1ind, res2ind, 0)
11941193

0 commit comments

Comments
 (0)