@@ -1688,10 +1688,11 @@ def kernel(X, Z, BLOCK: tl.constexpr):
16881688 for shape in reduce3d_shapes
16891689 for axis in [0 , 1 , 2 ]]
16901690invalid_config = [('sum' , 'float32' , (32 , 32 ), axis ) for axis in [2 , 3 ]]
1691+ negative_config = [('sum' , 'float32' , (32 , 32 ), - 1 )]
16911692
16921693
16931694@pytest .mark .parametrize ("op, dtype_str, shape, axis" ,
1694- reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config )
1695+ reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + negative_config )
16951696@pytest .mark .parametrize ("num_ctas" , num_ctas_list )
16961697def test_reduce (op , dtype_str , shape , axis , num_ctas , device ):
16971698 check_type_supported (dtype_str , device ) # bfloat16 on cc < 80 will not be tested
@@ -1739,8 +1740,8 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
17391740 z_dtype_str = get_reduced_dtype (dtype_str , op )
17401741 z_tri_dtype_str = z_dtype_str
17411742 # triton result
1742- ret_numel = 1 if axis is None else shape [ 1 - axis ]
1743- z_shape = (1 , ) if axis is None else tuple (shape_i for i , shape_i in enumerate (shape ) if i != axis )
1743+ non_negative_axis = axis if axis is None or axis >= 0 else len ( shape ) + axis
1744+ z_shape = (1 , ) if axis is None else tuple (shape_i for i , shape_i in enumerate (shape ) if i != non_negative_axis )
17441745 z_tri = to_triton (numpy_random (z_shape , dtype_str = z_dtype_str , rs = rs ), device = device , dst_type = z_tri_dtype_str )
17451746 BLOCK_K = 1 if len (shape ) == 2 else shape [2 ]
17461747 IS_3D = bool (len (shape ) == 3 )
@@ -1787,6 +1788,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
17871788 for axis in [1 , 0 ]
17881789 for shape in scan2d_shapes
17891790 for op in ['cumsum' , 'cumprod' , 'get_first_element' ]]
1791+ negative_config = [('cumsum' , 'float32' , (32 , 32 ), - 1 , 4 )]
17901792
17911793
17921794@triton .jit
@@ -1795,7 +1797,7 @@ def get_first_element(a, b):
17951797 return a
17961798
17971799
1798- @pytest .mark .parametrize ("op, dtype_str, shape, axis, num_warps" , scan_configs )
1800+ @pytest .mark .parametrize ("op, dtype_str, shape, axis, num_warps" , scan_configs + negative_config )
17991801def test_scan2d (op , dtype_str , shape , axis , num_warps , device ):
18001802 if is_hip ():
18011803 pytest .skip ("test_scan2d is not supported in HIP" )
0 commit comments