Skip to content

Commit a767ca4

Browse files
authored
[FRONTEND] fix semantic issue for max_num_imprecise_acc (triton-lang#2835)
1 parent 1bc9c0e commit a767ca4

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

python/test/unit/language/test_core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,6 +2567,31 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
25672567
'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
25682568

25692569

2570+
def test_max_num_imprecise_acc(device):
    """Check that an explicit ``max_num_imprecise_acc`` reaches the generated PTX.

    A single-program kernel performs one ``tl.dot`` on float8_e5m2 inputs with a
    float32 accumulator, passing ``max_num_imprecise_acc`` explicitly. The test
    then counts the ``add.f32`` instructions in the compiled PTX and compares
    that count with the value derived from the problem size.

    The check only runs on devices reporting compute capability (9, 0); on any
    other device the test is reported as skipped.
    """
    capability = torch.cuda.get_device_capability()
    if capability != (9, 0):
        # Skip explicitly instead of returning: a bare ``return`` would be
        # reported as a pass even though nothing was verified.
        pytest.skip("max_num_imprecise_acc is only exercised on compute capability 9.0")

    @triton.jit
    def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
               MAX_NUM_IMPRECISE_ACC: tl.constexpr):
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_k = tl.arange(0, BLOCK_K)
        # Operands are addressed as dense row-major tiles whose row stride is
        # the block size of the inner dimension.
        x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :])
        y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :])
        z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :])
        z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC)
        tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z)

    M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64
    x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device)
    y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device)
    z = torch.zeros((M, N), dtype=torch.float32, device=device)
    h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps)

    # NOTE(review): presumably each of the 32 * num_warps threads holds
    # (M * N) // (32 * num_warps) accumulator elements and issues one add.f32
    # per element for every MAX_NUM_IMPRECISE_ACC-wide slice of K -- confirm
    # against the backend lowering. Integer division keeps the expected
    # instruction count an int rather than a float.
    expected_adds = (M * N) // (32 * num_warps) * (K // MAX_NUM_IMPRECISE_ACC)
    assert h.asm["ptx"].count("add.f32") == expected_adds
2593+
2594+
25702595
@pytest.mark.parametrize('in_dtype', ['float32'])
25712596
def test_dot_mulbroadcastred(in_dtype, device):
25722597
capability = torch.cuda.get_device_capability()

python/triton/language/semantic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,10 +1306,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
13061306
assert acc.type == ret_ty
13071307

13081308
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
1309-
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc is None:
1310-
max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
1311-
else:
1312-
max_num_imprecise_acc = 0
1309+
if max_num_imprecise_acc is None:
1310+
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
1311+
max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
1312+
else:
1313+
max_num_imprecise_acc = 0
13131314

13141315
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
13151316

0 commit comments

Comments
 (0)