@@ -1254,8 +1254,10 @@ class dont_convert(tuple):
12541254 (CminConstant , (), ((S , S , S ), 0.5 )),
12551255 (Mean , (), ((S , S , S ),)),
12561256 (Mean , (1 ,), ((S , S , S ),), 'dim' , [0 ]),
1257+ (Mean , (1 , True ,), ((S , S , S ),), 'keepdim_dim' , [0 ]),
12571258 (Sum , (), ((S , S , S ),)),
12581259 (Sum , (1 ,), ((S , S , S ),), 'dim' , [0 ]),
1260+ (Sum , (1 , True ,), ((S , S , S ),), 'keepdim_dim' , [0 ]),
12591261 (Prod , (), ((S , S , S ),)),
12601262 (Prod , (), (prod_zeros (S , [0 , 1 ]),), 'zerosdim2' ),
12611263 (Prod , (), (prod_zeros (S , [0 , 2 ]),), 'zerosdim1' ),
@@ -1265,6 +1267,10 @@ class dont_convert(tuple):
12651267 (Prod , (1 ,), (prod_zeros (S , [0 , 1 ]),), 'zeros_dim2' , [0 ]),
12661268 (Prod , (1 ,), (prod_zeros (S , [0 , 2 ]),), 'zeros_dim1' , [0 ]),
12671269 (Prod , (1 ,), (prod_zeros (S , [1 , 2 ]),), 'zeros_dim0' , [0 ]),
1270+ (Prod , (1 , True ,), ((S , S , S ),), 'keepdim_dim' , [0 ]),
1271+ (Prod , (1 , True ,), (prod_zeros (S , [0 , 1 ]),), 'keepdim_zeros_dim2' , [0 ]),
1272+ (Prod , (1 , True ,), (prod_zeros (S , [0 , 2 ]),), 'keepdim_zeros_dim1' , [0 ]),
1273+ (Prod , (1 , True ), (prod_zeros (S , [1 , 2 ]),), 'keepdim_zeros_dim0' , [0 ]),
12681274 (Addmm , (), ((S , M ), (S , S ), (S , M )),),
12691275 (Addmm , (0.1 , 1 ), ((S , M ), (S , S ), (S , M )), 'coef' ),
12701276 (Addbmm , (), ((S , M ), (S , S , S ), (S , S , M )),),
@@ -1284,15 +1290,23 @@ class dont_convert(tuple):
12841290 (Min , (), ((S , S , S ),),),
12851291 (Max , (1 ,), ((S , S , S ),), 'dim' , [0 ]),
12861292 (Min , (1 ,), ((S , S , S ),), 'dim' , [0 ]),
1293+ (Max , (1 , True ), ((S , S , S ),), 'keepdim_dim' , [0 ]),
1294+ (Min , (1 , True ), ((S , S , S ),), 'keepdim_dim' , [0 ]),
12871295 (Mode , (1 ,), ((S , S , S ),), 'dim' , [0 ]),
1296+ (Mode , (1 , True ,), ((S , S , S ),), 'keepdim_dim' , [0 ]),
12881297 (Kthvalue , (2 , 0 ), ((S , S , S ),),),
1298+ (Kthvalue , (2 , 0 , True ), ((S , S , S ),), "keepdim" ),
12891299 (Median , (0 ,), ((S , S , S ),),),
1300+ (Median , (0 , True , ), ((S , S , S ),), "keepdim" ),
12901301 (Norm , (1.5 ,), (torch .rand (S , S , S ),), '1_5' ),
12911302 (Norm , (), ((S , S , S ),), '2' ),
12921303 (Norm , (3 ,), ((S , S , S ),), '3' ),
12931304 (Norm , (1.5 , 1 ), (torch .rand (S , S , S ),), '1_5_dim' , [1 ]),
12941305 (Norm , (2 , 1 ), ((S , S , S ),), '2_dim' , [1 ]),
12951306 (Norm , (3 , 1 ), ((S , S , S ),), '3_dim' , [1 ]),
1307+ (Norm , (1.5 , 1 , True ,), (torch .rand (S , S , S ),), 'keepdim_1_5_dim' , [1 ]),
1308+ (Norm , (2 , 1 , True ,), ((S , S , S ),), 'keepdim_2_dim' , [1 ]),
1309+ (Norm , (3 , 1 , True ), ((S , S , S ),), 'keepdim_3_dim' , [1 ]),
12961310 (Addcmul , (), ((S , S ), (S , S ), (S , S ))),
12971311 (Addcmul , (0.6 ,), ((S , S ), (S , S ), (S , S )), 'scale' ),
12981312 (Addcdiv , (), ((S , S ), (S , S ), torch .rand (S , S ) + 5e-2 )),
@@ -1388,20 +1402,31 @@ class dont_convert(tuple):
13881402 ('lerp' , (S , S , S ), ((S , S , S ), 0.4 )),
13891403 ('max' , (S , S , S ), ()),
13901404 ('max' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1405+ ('max' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
13911406 ('max' , (S , S , S ), ((S , S , S ),), 'elementwise' ),
13921407 ('min' , (S , S , S ), ()),
13931408 ('min' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1409+ ('min' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
13941410 ('min' , (S , S , S ), ((S , S , S ),), 'elementwise' ),
13951411 ('mean' , (S , S , S ), ()),
13961412 ('mean' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1413+ ('mean' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
1414+ ('median' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1415+ ('median' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
1416+ ('mode' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1417+ ('mode' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
13971418 ('sum' , (S , S , S ), ()),
13981419 ('sum' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1420+ ('sum' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
13991421 ('prod' , (S , S , S ), ()),
14001422 ('prod' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1423+ ('prod' , (S , S , S ), (1 , True ,), 'keepdim_dim' , [0 ]),
14011424 ('var' , (S , S , S ), ()),
14021425 ('var' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1426+ ('var' , (S , S , S ), (1 , True ), 'keepdim_dim' , [0 ]),
14031427 ('std' , (S , S , S ), ()),
14041428 ('std' , (S , S , S ), (1 ,), 'dim' , [0 ]),
1429+ ('std' , (S , S , S ), (1 , True ), 'keepdim_dim' , [0 ]),
14051430 ('renorm' , (S , S , S ), (2 , 1 , 0.5 ), 'dim' , [1 ]),
14061431 ('renorm' , (S , S , S ), (1 , 2 , 3 ), 'norm_1' ),
14071432 ('repeat' , (S , S , S , S ), (2 , 3 , 1 , 4 )),
@@ -1424,6 +1449,7 @@ class dont_convert(tuple):
14241449 ('addcdiv' , (S , S ), (0.5 , (S , S ), (S , S )), 'scale' ),
14251450 ('norm' , (S , S , S ), (2 ,)),
14261451 ('norm' , (S , S , S ), (2 , 1 ), 'dim' , [1 ]),
1452+ ('norm' , (S , S , S ), (2 , 1 , True ), 'keepdim_dim' , [0 ]),
14271453 ('dist' , (S , S , S ), ((S , S , S ),)),
14281454 ('dist' , (S , S , S ), ((S , S , S ), 4 ), '4' ),
14291455 ('index_select' , (S , S , S ), (0 , index_variable (2 , S )), 'dim' , [0 ]),
0 commit comments