@@ -159,17 +159,17 @@ def _testSelection(self, torchfn, mathfn):
159159 # with indices
160160 m1 = torch .randn (100 , 100 )
161161 res1val , res1ind = torchfn (m1 , 1 )
162- res2val = m1 [:, 0 :1 ].clone ()
162+ res2val = m1 [:, 0 :1 ].clone (). squeeze ()
163163 res2ind = res1ind .clone ().fill_ (0 )
164164 for i , j in iter_indices (m1 ):
165- if mathfn (res2val [i , 0 ], m1 [i , j ]) != res2val [i , 0 ]:
166- res2val [i , 0 ] = m1 [i , j ]
167- res2ind [i , 0 ] = j
165+ if mathfn (res2val [i ], m1 [i , j ]) != res2val [i ]:
166+ res2val [i ] = m1 [i , j ]
167+ res2ind [i ] = j
168168
169169 maxerr = 0
170170 for i in range (res1val .size (0 )):
171- maxerr = max (maxerr , abs (res1val [i ][ 0 ] - res2val [i ][ 0 ]))
172- self .assertEqual (res1ind [i ][ 0 ] , res2ind [i ][ 0 ])
171+ maxerr = max (maxerr , abs (res1val [i ] - res2val [i ]))
172+ self .assertEqual (res1ind [i ], res2ind [i ])
173173 self .assertLessEqual (abs (maxerr ), 1e-5 )
174174
175175 # NaNs
@@ -514,22 +514,22 @@ def test_addbmm(self):
514514 res2 = torch .Tensor ().resize_as_ (res [0 ]).zero_ ()
515515
516516 res2 .addbmm_ (b1 , b2 )
517- self .assertEqual (res2 , res .sum (0 )[ 0 ] )
517+ self .assertEqual (res2 , res .sum (0 ))
518518
519519 res2 .addbmm_ (1 , b1 , b2 )
520- self .assertEqual (res2 , res .sum (0 )[ 0 ] * 2 )
520+ self .assertEqual (res2 , res .sum (0 ) * 2 )
521521
522522 res2 .addbmm_ (1. , .5 , b1 , b2 )
523- self .assertEqual (res2 , res .sum (0 )[ 0 ] * 2.5 )
523+ self .assertEqual (res2 , res .sum (0 ) * 2.5 )
524524
525525 res3 = torch .addbmm (1 , res2 , 0 , b1 , b2 )
526526 self .assertEqual (res3 , res2 )
527527
528528 res4 = torch .addbmm (1 , res2 , .5 , b1 , b2 )
529- self .assertEqual (res4 , res .sum (0 )[ 0 ] * 3 )
529+ self .assertEqual (res4 , res .sum (0 ) * 3 )
530530
531531 res5 = torch .addbmm (0 , res2 , 1 , b1 , b2 )
532- self .assertEqual (res5 , res .sum (0 )[ 0 ] )
532+ self .assertEqual (res5 , res .sum (0 ))
533533
534534 res6 = torch .addbmm (.1 , res2 , .5 , b1 , b2 )
535535 self .assertEqual (res6 , res2 * .1 + res .sum (0 ) * .5 )
@@ -744,7 +744,7 @@ def renorm(matrix, value, dim, max_norm):
744744 m1 = matrix .transpose (dim , 0 ).contiguous ()
745745 # collapse non-dim dimensions.
746746 m2 = m1 .clone ().resize_ (m1 .size (0 ), int (math .floor (m1 .nelement () / m1 .size (0 ))))
747- norms = m2 .norm (value , 1 )
747+ norms = m2 .norm (value , 1 , True )
748748 # clip
749749 new_norms = norms .clone ()
750750 new_norms [torch .gt (norms , max_norm )] = max_norm
@@ -1070,23 +1070,23 @@ def test_kthvalue(self):
10701070 res1val , res1ind = torch .kthvalue (x , k )
10711071 res2val , res2ind = torch .sort (x )
10721072
1073- self .assertEqual (res1val [:, :, 0 ], res2val [:, :, k - 1 ], 0 )
1074- self .assertEqual (res1ind [:, :, 0 ], res2ind [:, :, k - 1 ], 0 )
1073+ self .assertEqual (res1val [:, :], res2val [:, :, k - 1 ], 0 )
1074+ self .assertEqual (res1ind [:, :], res2ind [:, :, k - 1 ], 0 )
10751075 # test use of result tensors
10761076 k = random .randint (1 , SIZE )
10771077 res1val = torch .Tensor ()
10781078 res1ind = torch .LongTensor ()
10791079 torch .kthvalue (x , k , out = (res1val , res1ind ))
10801080 res2val , res2ind = torch .sort (x )
1081- self .assertEqual (res1val [:, :, 0 ], res2val [:, :, k - 1 ], 0 )
1082- self .assertEqual (res1ind [:, :, 0 ], res2ind [:, :, k - 1 ], 0 )
1081+ self .assertEqual (res1val [:, :], res2val [:, :, k - 1 ], 0 )
1082+ self .assertEqual (res1ind [:, :], res2ind [:, :, k - 1 ], 0 )
10831083
10841084 # test non-default dim
10851085 k = random .randint (1 , SIZE )
10861086 res1val , res1ind = torch .kthvalue (x , k , 0 )
10871087 res2val , res2ind = torch .sort (x , 0 )
1088- self .assertEqual (res1val [ 0 ] , res2val [k - 1 ], 0 )
1089- self .assertEqual (res1ind [ 0 ] , res2ind [k - 1 ], 0 )
1088+ self .assertEqual (res1val , res2val [k - 1 ], 0 )
1089+ self .assertEqual (res1ind , res2ind [k - 1 ], 0 )
10901090
10911091 # non-contiguous
10921092 y = x .narrow (1 , 0 , 1 )
@@ -1110,12 +1110,12 @@ def test_median(self):
11101110 x = torch .rand (size , size )
11111111 x0 = x .clone ()
11121112
1113- res1val , res1ind = torch .median (x )
1113+ res1val , res1ind = torch .median (x , False )
11141114 res2val , res2ind = torch .sort (x )
11151115 ind = int (math .floor ((size + 1 ) / 2 ) - 1 )
11161116
1117- self .assertEqual (res2val .select (1 , ind ), res1val . select ( 1 , 0 ) , 0 )
1118- self .assertEqual (res2val .select (1 , ind ), res1val . select ( 1 , 0 ) , 0 )
1117+ self .assertEqual (res2val .select (1 , ind ), res1val , 0 )
1118+ self .assertEqual (res2val .select (1 , ind ), res1val , 0 )
11191119
11201120 # Test use of result tensor
11211121 res2val = torch .Tensor ()
@@ -1127,8 +1127,8 @@ def test_median(self):
11271127 # Test non-default dim
11281128 res1val , res1ind = torch .median (x , 0 )
11291129 res2val , res2ind = torch .sort (x , 0 )
1130- self .assertEqual (res1val [ 0 ] , res2val [ind ], 0 )
1131- self .assertEqual (res1ind [ 0 ] , res2ind [ind ], 0 )
1130+ self .assertEqual (res1val , res2val [ind ], 0 )
1131+ self .assertEqual (res1ind , res2ind [ind ], 0 )
11321132
11331133 # input unchanged
11341134 self .assertEqual (x , x0 , 0 )
@@ -1140,9 +1140,9 @@ def test_mode(self):
11401140 x0 = x .clone ()
11411141
11421142 # Pre-calculated results.
1143- res1val = torch .Tensor (SIZE , 1 ).fill_ (1 )
1143+ res1val = torch .Tensor (SIZE ).fill_ (1 )
11441144 # The indices are the position of the last appearance of the mode element.
1145- res1ind = torch .LongTensor (SIZE , 1 ).fill_ (1 )
1145+ res1ind = torch .LongTensor (SIZE ).fill_ (1 )
11461146 res1ind [0 ] = SIZE - 1
11471147 res1ind [1 ] = SIZE - 1
11481148
@@ -1160,8 +1160,8 @@ def test_mode(self):
11601160
11611161 # Test non-default dim
11621162 res2val , res2ind = torch .mode (x , 0 )
1163- self .assertEqual (res1val . view ( 1 , SIZE ) , res2val , 0 )
1164- self .assertEqual (res1ind . view ( 1 , SIZE ) , res2ind , 0 )
1163+ self .assertEqual (res1val , res2val , 0 )
1164+ self .assertEqual (res1ind , res2ind , 0 )
11651165
11661166 # input unchanged
11671167 self .assertEqual (x , x0 , 0 )
@@ -2217,7 +2217,7 @@ def _test_gather(self, cast, test_bounds=True):
22172217 self .assertRaises (RuntimeError , lambda : torch .gather (src , dim , idx ))
22182218
22192219 src = cast (torch .randn (3 , 4 , 5 ))
2220- expected , idx = src .max (2 )
2220+ expected , idx = src .max (2 , True )
22212221 expected = cast (expected )
22222222 idx = cast (idx )
22232223 actual = torch .gather (src , 2 , idx )
0 commit comments