Skip to content

Commit f107df1

Browse files
authored
[FRONTEND] Support negative axis for reduce and scan (triton-lang#2849)
1 parent 5c8eaf9 commit f107df1

File tree

2 files changed: 10 additions and 5 deletions.

python/test/unit/language/test_core.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1688,10 +1688,11 @@ def kernel(X, Z, BLOCK: tl.constexpr):
1688 1688   for shape in reduce3d_shapes
1689 1689   for axis in [0, 1, 2]]
1690 1690   invalid_config = [('sum', 'float32', (32, 32), axis) for axis in [2, 3]]
     1691 + negative_config = [('sum', 'float32', (32, 32), -1)]
1691 1692
1692 1693
1693 1694   @pytest.mark.parametrize("op, dtype_str, shape, axis",
1694      - reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config)
     1695 + reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + negative_config)
1695 1696   @pytest.mark.parametrize("num_ctas", num_ctas_list)
1696 1697   def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
1697 1698   check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
@@ -1739,8 +1740,8 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
1739 1740   z_dtype_str = get_reduced_dtype(dtype_str, op)
1740 1741   z_tri_dtype_str = z_dtype_str
1741 1742   # triton result
1742      - ret_numel = 1 if axis is None else shape[1 - axis]
1743      - z_shape = (1, ) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis)
     1743 + non_negative_axis = axis if axis is None or axis >= 0 else len(shape) + axis
     1744 + z_shape = (1, ) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != non_negative_axis)
1744 1745   z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
1745 1746   BLOCK_K = 1 if len(shape) == 2 else shape[2]
1746 1747   IS_3D = bool(len(shape) == 3)
@@ -1787,6 +1788,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
1787 1788   for axis in [1, 0]
1788 1789   for shape in scan2d_shapes
1789 1790   for op in ['cumsum', 'cumprod', 'get_first_element']]
     1791 + negative_config = [('cumsum', 'float32', (32, 32), -1, 4)]
1790 1792
1791 1793
1792 1794   @triton.jit
@@ -1795,7 +1797,7 @@ def get_first_element(a, b):
1795 1797   return a
1796 1798
1797 1799
1798      - @pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
     1800 + @pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs + negative_config)
1799 1801   def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
1800 1802   if is_hip():
1801 1803   pytest.skip("test_scan2d is not supported in HIP")

python/triton/language/core.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1477,8 +1477,9 @@ def make_combine_region(reduce_op):
1477 1477   handles = [r.handle for r in results]
1478 1478   _builder.create_reduce_ret(*handles)
1479 1479
     1480 + axis = _constexpr_to_value(axis)
1480 1481   if axis is not None:
1481      - axis = _constexpr_to_value(axis)
     1482 + axis = _wrap_axis(axis, len(input[0].shape))
1482 1483   return semantic.reduction(input, axis, make_combine_region, _builder)
1483 1484
1484 1485

@@ -1557,6 +1558,8 @@ def make_combine_region(scan_op):
1557 1558   _builder.create_scan_ret(*handles)
1558 1559
1559 1560   axis = _constexpr_to_value(axis)
     1561 + if axis is not None:
     1562 + axis = _wrap_axis(axis, len(input[0].shape))
1560 1563   return semantic.associative_scan(input, axis, make_combine_region, _builder)
1561 1564
1562 1565
0 commit comments

Comments
 (0)