@@ -155,12 +155,15 @@ def tmp(t):
155155 ('fmod' , small_3d , lambda t : [small_3d_positive (t )], 'tensor' ),
156156 ('chunk' , medium_2d , lambda t : [4 ],),
157157 ('chunk' , medium_2d , lambda t : [4 , 1 ], 'dim' ),
158+ ('chunk' , medium_2d , lambda t : [4 , - 2 ], 'neg_dim' ),
158159 ('clamp' , medium_2d_scaled , lambda t : [- 1 , 5 ],),
159160 ('clone' , medium_2d , lambda t : [],),
160161 ('contiguous' , medium_2d , lambda t : [],),
161162 ('cross' , new_t (M , 3 , M ), lambda t : [new_t (M , 3 , M )(t )],),
162163 ('cumprod' , small_3d , lambda t : [1 ],),
164+ ('cumprod' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
163165 ('cumsum' , small_3d , lambda t : [1 ],),
166+ ('cumsum' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
164167 ('dim' , small_3d , lambda t : [],),
165168 ('dist' , small_2d , lambda t : [small_2d (t )],),
166169 ('dist' , small_2d , lambda t : [small_2d (t ), 3 ], '3_norm' ),
@@ -188,52 +191,72 @@ def tmp(t):
188191 # TODO: positive case
189192 ('kthvalue' , small_3d_unique , lambda t : [3 ],),
190193 ('kthvalue' , small_3d_unique , lambda t : [3 , 1 ], 'dim' ),
194+ ('kthvalue' , small_3d_unique , lambda t : [3 , - 1 ], 'neg_dim' ),
191195 ('lerp' , small_3d , lambda t : [small_3d (t ), 0.3 ],),
192196 ('max' , small_3d_unique , lambda t : [],),
193197 ('max' , small_3d_unique , lambda t : [1 ], 'dim' ),
198+ ('max' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' ),
194199 ('max' , medium_2d , lambda t : [medium_2d (t )], 'elementwise' ),
195200 ('min' , small_3d_unique , lambda t : [],),
196201 ('min' , small_3d_unique , lambda t : [1 ], 'dim' ),
202+ ('min' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' ),
197203 ('min' , medium_2d , lambda t : [medium_2d (t )], 'elementwise' ),
198204 ('mean' , small_3d , lambda t : [],),
205+ ('mean' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
199206 ('mean' , small_3d , lambda t : [1 ], 'dim' ),
200207 ('mode' , small_3d , lambda t : [],),
201208 ('mode' , small_3d , lambda t : [1 ], 'dim' ),
209+ ('mode' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
202210 ('remainder' , small_3d , lambda t : [3 ], 'value' ),
203211 ('remainder' , small_3d , lambda t : [small_3d_positive (t )], 'tensor' ),
204212 ('std' , small_3d , lambda t : [],),
205213 ('std' , small_3d , lambda t : [1 ], 'dim' ),
214+ ('std' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
206215 ('var' , small_3d , lambda t : [],),
207216 ('var' , small_3d , lambda t : [1 ], 'dim' ),
217+ ('var' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
208218 ('ndimension' , small_3d , lambda t : [],),
209219 ('nelement' , small_3d , lambda t : [],),
210220 ('numel' , small_3d , lambda t : [],),
211221 ('narrow' , small_3d , lambda t : [1 , 3 , 2 ],),
222+ ('narrow' , small_3d , lambda t : [- 1 , 3 , 2 ], 'neg_dim' ),
212223 ('nonzero' , small_3d , lambda t : [],),
213224 ('norm' , small_3d , lambda t : [],),
214225 ('norm' , small_3d , lambda t : [3 ], '3_norm' ),
215226 ('norm' , small_3d , lambda t : [3 , 0 ], '3_norm_dim' ),
227+ ('norm' , small_3d , lambda t : [3 , - 2 ], '3_norm_neg_dim' ),
216228 ('ones' , small_3d , lambda t : [1 , 2 , 3 , 4 , 5 ],),
217229 ('permute' , new_t (1 , 2 , 3 , 4 ), lambda t : [2 , 1 , 3 , 0 ],),
218230 ('prod' , small_2d_oneish , lambda t : [],),
219231 ('prod' , small_3d , lambda t : [1 ], 'dim' ),
232+ ('prod' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
220233 ('sum' , small_2d , lambda t : [],),
221234 ('sum' , small_3d , lambda t : [1 ], 'dim' ),
235+ ('sum' , small_3d , lambda t : [- 1 ], 'neg_dim' ),
222236 ('renorm' , small_3d , lambda t : [2 , 1 , 1 ], '2_norm' ),
237+ ('renorm' , small_3d , lambda t : [2 , - 1 , 1 ], '2_norm_neg_dim' ),
223238 ('renorm' , small_3d , lambda t : [1.5 , 1 , 1 ], '1_5_norm' ),
224239 ('repeat' , small_2d , lambda t : [2 , 2 , 2 ],),
225240 ('size' , new_t (1 , 2 , 3 , 4 ), lambda t : [],),
241+ ('size' , new_t (1 , 2 , 3 , 4 ), lambda t : [1 ], 'dim' ),
242+ ('size' , new_t (1 , 2 , 3 , 4 ), lambda t : [- 2 ], 'neg_dim' ),
226243 ('sort' , small_3d_unique , lambda t : [],),
227244 ('sort' , small_3d_unique , lambda t : [1 ], 'dim' ),
245+ ('sort' , small_3d_unique , lambda t : [- 1 ], 'neg_dim' ),
228246 ('sort' , small_3d_unique , lambda t : [1 , True ], 'dim_descending' ),
247+ ('sort' , small_3d_unique , lambda t : [- 1 , True ], 'neg_dim_descending' ),
229248 ('split' , small_3d , lambda t : [2 ],),
230249 ('split' , small_3d , lambda t : [2 , 1 ], 'dim' ),
250+ ('split' , small_3d , lambda t : [2 , - 3 ], 'neg_dim' ),
231251 ('squeeze' , new_t (1 , 2 , 1 , 4 ), lambda t : [],),
232252 ('squeeze' , new_t (1 , 2 , 1 , 4 ), lambda t : [2 ], 'dim' ),
253+ ('squeeze' , new_t (1 , 2 , 1 , 4 ), lambda t : [- 2 ], 'neg_dim' ),
233254 ('t' , new_t (1 , 2 ), lambda t : [],),
234255 ('transpose' , new_t (1 , 2 , 3 , 4 ), lambda t : [1 , 2 ],),
256+ ('transpose' , new_t (1 , 2 , 3 , 4 ), lambda t : [- 1 , - 2 ], 'neg_dim' ),
235257 ('to_list' , small_3d , lambda t : [],),
236258 ('topk' , small_3d , lambda t : [2 , 1 , False , True ], 'dim_sort' ),
259+ ('topk' , small_3d , lambda t : [2 , - 1 , False , True ], 'neg_dim_sort' ),
237260 ('topk' , small_3d , lambda t : [2 , 1 , True , True ], 'dim_desc_sort' ),
238261 ('trace' , medium_2d , lambda t : [],),
239262 ('tril' , medium_2d , lambda t : [],),
@@ -243,6 +266,7 @@ def tmp(t):
243266 ('triu' , medium_2d , lambda t : [2 ], 'positive' ),
244267 ('triu' , medium_2d , lambda t : [- 2 ], 'negative' ),
245268 ('unsqueeze' , new_t (2 , 3 , 4 ), lambda t : [2 ],),
269+ ('unsqueeze' , new_t (2 , 3 , 4 ), lambda t : [- 2 ], 'neg_dim' ),
246270 ('view' , small_3d , lambda t : [100 , 10 ],),
247271 ('view_as' , small_3d , lambda t : [t (100 , 10 )],),
248272 ('zero' , small_3d , lambda t : [],),
@@ -467,6 +491,9 @@ def test_scatter_cpu(self):
467491 def test_scatter_cpu_dim (self ):
468492 self ._test_scatter (torch .randn (4 , 4 ), dim = 1 )
469493
494+ def test_scatter_cpu_neg_dim (self ):
495+ self ._test_scatter (torch .randn (4 , 4 ), dim = - 2 )
496+
470497 def test_scatter_cpu_sizes (self ):
471498 self ._test_scatter (torch .randn (6 , 4 ), chunk_sizes = (2 , 4 ))
472499
@@ -476,6 +503,9 @@ def test_scatter_gpu(self):
476503 def test_scatter_gpu_dim (self ):
477504 self ._test_scatter (torch .randn (4 , 4 ).cuda (), dim = 1 )
478505
506+ def test_scatter_gpu_neg_dim (self ):
507+ self ._test_scatter (torch .randn (4 , 4 ).cuda (), dim = - 2 )
508+
479509 def test_scatter_gpu_sizes (self ):
480510 self ._test_scatter (torch .randn (6 , 4 ).cuda (), chunk_sizes = (2 , 4 ))
481511
0 commit comments