Skip to content

Commit f6a00fa

Browse files
committed
Add autograd tests for keepdim
1 parent be5191a commit f6a00fa

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/test_autograd.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,8 +1254,10 @@ class dont_convert(tuple):
12541254
(CminConstant, (), ((S, S, S), 0.5)),
12551255
(Mean, (), ((S, S, S),)),
12561256
(Mean, (1,), ((S, S, S),), 'dim', [0]),
1257+
(Mean, (1, True,), ((S, S, S),), 'keepdim_dim', [0]),
12571258
(Sum, (), ((S, S, S),)),
12581259
(Sum, (1,), ((S, S, S),), 'dim', [0]),
1260+
(Sum, (1, True,), ((S, S, S),), 'keepdim_dim', [0]),
12591261
(Prod, (), ((S, S, S),)),
12601262
(Prod, (), (prod_zeros(S, [0, 1]),), 'zerosdim2'),
12611263
(Prod, (), (prod_zeros(S, [0, 2]),), 'zerosdim1'),
@@ -1265,6 +1267,10 @@ class dont_convert(tuple):
12651267
(Prod, (1,), (prod_zeros(S, [0, 1]),), 'zeros_dim2', [0]),
12661268
(Prod, (1,), (prod_zeros(S, [0, 2]),), 'zeros_dim1', [0]),
12671269
(Prod, (1,), (prod_zeros(S, [1, 2]),), 'zeros_dim0', [0]),
1270+
(Prod, (1, True,), ((S, S, S),), 'keepdim_dim', [0]),
1271+
(Prod, (1, True,), (prod_zeros(S, [0, 1]),), 'keepdim_zeros_dim2', [0]),
1272+
(Prod, (1, True,), (prod_zeros(S, [0, 2]),), 'keepdim_zeros_dim1', [0]),
1273+
(Prod, (1, True), (prod_zeros(S, [1, 2]),), 'keepdim_zeros_dim0', [0]),
12681274
(Addmm, (), ((S, M), (S, S), (S, M)),),
12691275
(Addmm, (0.1, 1), ((S, M), (S, S), (S, M)), 'coef'),
12701276
(Addbmm, (), ((S, M), (S, S, S), (S, S, M)),),
@@ -1284,15 +1290,23 @@ class dont_convert(tuple):
12841290
(Min, (), ((S, S, S),),),
12851291
(Max, (1,), ((S, S, S),), 'dim', [0]),
12861292
(Min, (1,), ((S, S, S),), 'dim', [0]),
1293+
(Max, (1, True), ((S, S, S),), 'keepdim_dim', [0]),
1294+
(Min, (1, True), ((S, S, S),), 'keepdim_dim', [0]),
12871295
(Mode, (1,), ((S, S, S),), 'dim', [0]),
1296+
(Mode, (1, True,), ((S, S, S),), 'keepdim_dim', [0]),
12881297
(Kthvalue, (2, 0), ((S, S, S),),),
1298+
(Kthvalue, (2, 0, True), ((S, S, S),), "keepdim"),
12891299
(Median, (0,), ((S, S, S),),),
1300+
(Median, (0, True, ), ((S, S, S),), "keepdim"),
12901301
(Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
12911302
(Norm, (), ((S, S, S),), '2'),
12921303
(Norm, (3,), ((S, S, S),), '3'),
12931304
(Norm, (1.5, 1), (torch.rand(S, S, S),), '1_5_dim', [1]),
12941305
(Norm, (2, 1), ((S, S, S),), '2_dim', [1]),
12951306
(Norm, (3, 1), ((S, S, S),), '3_dim', [1]),
1307+
(Norm, (1.5, 1, True,), (torch.rand(S, S, S),), 'keepdim_1_5_dim', [1]),
1308+
(Norm, (2, 1, True,), ((S, S, S),), 'keepdim_2_dim', [1]),
1309+
(Norm, (3, 1, True), ((S, S, S),), 'keepdim_3_dim', [1]),
12961310
(Addcmul, (), ((S, S), (S, S), (S, S))),
12971311
(Addcmul, (0.6,), ((S, S), (S, S), (S, S)), 'scale'),
12981312
(Addcdiv, (), ((S, S), (S, S), torch.rand(S, S) + 5e-2)),
@@ -1388,20 +1402,31 @@ class dont_convert(tuple):
13881402
('lerp', (S, S, S), ((S, S, S), 0.4)),
13891403
('max', (S, S, S), ()),
13901404
('max', (S, S, S), (1,), 'dim', [0]),
1405+
('max', (S, S, S), (1, True,), 'keepdim_dim', [0]),
13911406
('max', (S, S, S), ((S, S, S),), 'elementwise'),
13921407
('min', (S, S, S), ()),
13931408
('min', (S, S, S), (1,), 'dim', [0]),
1409+
('min', (S, S, S), (1, True,), 'keepdim_dim', [0]),
13941410
('min', (S, S, S), ((S, S, S),), 'elementwise'),
13951411
('mean', (S, S, S), ()),
13961412
('mean', (S, S, S), (1,), 'dim', [0]),
1413+
('mean', (S, S, S), (1, True,), 'keepdim_dim', [0]),
1414+
('median', (S, S, S), (1,), 'dim', [0]),
1415+
('median', (S, S, S), (1, True,), 'keepdim_dim', [0]),
1416+
('mode', (S, S, S), (1,), 'dim', [0]),
1417+
('mode', (S, S, S), (1, True,), 'keepdim_dim', [0]),
13971418
('sum', (S, S, S), ()),
13981419
('sum', (S, S, S), (1,), 'dim', [0]),
1420+
('sum', (S, S, S), (1, True,), 'keepdim_dim', [0]),
13991421
('prod', (S, S, S), ()),
14001422
('prod', (S, S, S), (1,), 'dim', [0]),
1423+
('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]),
14011424
('var', (S, S, S), ()),
14021425
('var', (S, S, S), (1,), 'dim', [0]),
1426+
('var', (S, S, S), (1, True), 'keepdim_dim', [0]),
14031427
('std', (S, S, S), ()),
14041428
('std', (S, S, S), (1,), 'dim', [0]),
1429+
('std', (S, S, S), (1, True), 'keepdim_dim', [0]),
14051430
('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]),
14061431
('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
14071432
('repeat', (S, S, S, S), (2, 3, 1, 4)),
@@ -1424,6 +1449,7 @@ class dont_convert(tuple):
14241449
('addcdiv', (S, S), (0.5, (S, S), (S, S)), 'scale'),
14251450
('norm', (S, S, S), (2,)),
14261451
('norm', (S, S, S), (2, 1), 'dim', [1]),
1452+
('norm', (S, S, S), (2, 1, True), 'keepdim_dim', [0]),
14271453
('dist', (S, S, S), ((S, S, S),)),
14281454
('dist', (S, S, S), ((S, S, S), 4), '4'),
14291455
('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]),

0 commit comments

Comments
 (0)