@@ -158,7 +158,7 @@ def _testSelection(self, torchfn, mathfn):
158158
159159 # with indices
160160 m1 = torch .randn (100 , 100 )
161- res1val , res1ind = torchfn (m1 , 1 )
161+ res1val , res1ind = torchfn (m1 , 1 , False )
162162 res2val = m1 [:, 0 :1 ].clone ().squeeze ()
163163 res2ind = res1ind .clone ().fill_ (0 )
164164 for i , j in iter_indices (m1 ):
@@ -206,9 +206,9 @@ def fn(t, dim, keepdim=True):
206206 return ans if not isinstance (ans , tuple ) else ans [0 ]
207207
208208 dim = random .randint (0 , 2 )
209- self .assertEqual (fn (x , dim , True ).unsqueeze (dim ), fn (x , dim ))
210- self .assertEqual (x .ndimension () - 1 , fn (x , dim , True ).ndimension ())
211- self .assertEqual (x .ndimension (), fn (x , dim ).ndimension ())
209+ self .assertEqual (fn (x , dim , False ).unsqueeze (dim ), fn (x , dim ))
210+ self .assertEqual (x .ndimension () - 1 , fn (x , dim , False ).ndimension ())
211+ self .assertEqual (x .ndimension (), fn (x , dim , True ).ndimension ())
212212
213213 # check 1-d behavior
214214 x = torch .randn (1 )
@@ -543,22 +543,22 @@ def test_addbmm(self):
543543 res2 = torch .Tensor ().resize_as_ (res [0 ]).zero_ ()
544544
545545 res2 .addbmm_ (b1 , b2 )
546- self .assertEqual (res2 , res .sum (0 ))
546+ self .assertEqual (res2 , res .sum (0 , False ))
547547
548548 res2 .addbmm_ (1 , b1 , b2 )
549- self .assertEqual (res2 , res .sum (0 ) * 2 )
549+ self .assertEqual (res2 , res .sum (0 , False ) * 2 )
550550
551551 res2 .addbmm_ (1. , .5 , b1 , b2 )
552- self .assertEqual (res2 , res .sum (0 ) * 2.5 )
552+ self .assertEqual (res2 , res .sum (0 , False ) * 2.5 )
553553
554554 res3 = torch .addbmm (1 , res2 , 0 , b1 , b2 )
555555 self .assertEqual (res3 , res2 )
556556
557557 res4 = torch .addbmm (1 , res2 , .5 , b1 , b2 )
558- self .assertEqual (res4 , res .sum (0 ) * 3 )
558+ self .assertEqual (res4 , res .sum (0 , False ) * 3 )
559559
560560 res5 = torch .addbmm (0 , res2 , 1 , b1 , b2 )
561- self .assertEqual (res5 , res .sum (0 ))
561+ self .assertEqual (res5 , res .sum (0 , False ))
562562
563563 res6 = torch .addbmm (.1 , res2 , .5 , b1 , b2 )
564564 self .assertEqual (res6 , res2 * .1 + res .sum (0 ) * .5 )
@@ -1096,7 +1096,7 @@ def test_kthvalue(self):
10961096 x0 = x .clone ()
10971097
10981098 k = random .randint (1 , SIZE )
1099- res1val , res1ind = torch .kthvalue (x , k )
1099+ res1val , res1ind = torch .kthvalue (x , k , False )
11001100 res2val , res2ind = torch .sort (x )
11011101
11021102 self .assertEqual (res1val [:, :], res2val [:, :, k - 1 ], 0 )
@@ -1105,14 +1105,14 @@ def test_kthvalue(self):
11051105 k = random .randint (1 , SIZE )
11061106 res1val = torch .Tensor ()
11071107 res1ind = torch .LongTensor ()
1108- torch .kthvalue (x , k , out = (res1val , res1ind ))
1108+ torch .kthvalue (x , k , False , out = (res1val , res1ind ))
11091109 res2val , res2ind = torch .sort (x )
11101110 self .assertEqual (res1val [:, :], res2val [:, :, k - 1 ], 0 )
11111111 self .assertEqual (res1ind [:, :], res2ind [:, :, k - 1 ], 0 )
11121112
11131113 # test non-default dim
11141114 k = random .randint (1 , SIZE )
1115- res1val , res1ind = torch .kthvalue (x , k , 0 )
1115+ res1val , res1ind = torch .kthvalue (x , k , 0 , False )
11161116 res2val , res2ind = torch .sort (x , 0 )
11171117 self .assertEqual (res1val , res2val [k - 1 ], 0 )
11181118 self .assertEqual (res1ind , res2ind [k - 1 ], 0 )
@@ -1139,7 +1139,7 @@ def test_median(self):
11391139 x = torch .rand (size , size )
11401140 x0 = x .clone ()
11411141
1142- res1val , res1ind = torch .median (x , False )
1142+ res1val , res1ind = torch .median (x , keepdim = False )
11431143 res2val , res2ind = torch .sort (x )
11441144 ind = int (math .floor ((size + 1 ) / 2 ) - 1 )
11451145
@@ -1149,12 +1149,12 @@ def test_median(self):
11491149 # Test use of result tensor
11501150 res2val = torch .Tensor ()
11511151 res2ind = torch .LongTensor ()
1152- torch .median (x , out = (res2val , res2ind ))
1152+ torch .median (x , keepdim = False , out = (res2val , res2ind ))
11531153 self .assertEqual (res2val , res1val , 0 )
11541154 self .assertEqual (res2ind , res1ind , 0 )
11551155
11561156 # Test non-default dim
1157- res1val , res1ind = torch .median (x , 0 )
1157+ res1val , res1ind = torch .median (x , 0 , keepdim = False )
11581158 res2val , res2ind = torch .sort (x , 0 )
11591159 self .assertEqual (res1val , res2val [ind ], 0 )
11601160 self .assertEqual (res1ind , res2ind [ind ], 0 )
@@ -1175,20 +1175,19 @@ def test_mode(self):
11751175 res1ind [0 ] = SIZE - 1
11761176 res1ind [1 ] = SIZE - 1
11771177
1178- res2val , res2ind = torch .mode (x )
1179-
1178+ res2val , res2ind = torch .mode (x , keepdim = False )
11801179 self .assertEqual (res1val , res2val , 0 )
11811180 self .assertEqual (res1ind , res2ind , 0 )
11821181
11831182 # Test use of result tensor
11841183 res2val = torch .Tensor ()
11851184 res2ind = torch .LongTensor ()
1186- torch .mode (x , out = (res2val , res2ind ))
1185+ torch .mode (x , keepdim = False , out = (res2val , res2ind ))
11871186 self .assertEqual (res1val , res2val , 0 )
11881187 self .assertEqual (res1ind , res2ind , 0 )
11891188
11901189 # Test non-default dim
1191- res2val , res2ind = torch .mode (x , 0 )
1190+ res2val , res2ind = torch .mode (x , 0 , False )
11921191 self .assertEqual (res1val , res2val , 0 )
11931192 self .assertEqual (res1ind , res2ind , 0 )
11941193
0 commit comments