Skip to content

Commit 39f4473

Browse files
authored
[FRONTEND] fix default max_num_imprecise_acc (triton-lang#2804)
1 parent b442db2 commit 39f4473

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

python/triton/compiler/backends/cuda.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def __init__(self, device_type: tuple) -> None:
78 78
def parse_options(self, opts) -> Any:
79 79
args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
80 80
args["allow_fp8e4nv"] = self.capability >= 89
81-
args["max_num_imprecise_acc_default"] = 0 if self.capability >= 89 else None
81+
args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
82 82
return CUDAOptions(**args)
83 83

84 84
@staticmethod

python/triton/language/semantic.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1276,11 +1276,10 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
1276 1276
assert acc.type == ret_ty
1277 1277

1278 1278
# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
1279-
max_num_imprecise_acc = 0
1280-
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
1279+
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc is None:
1281 1280
max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
1282-
if max_num_imprecise_acc is None:
1283-
max_num_imprecise_acc = 2**30
1281+
else:
1282+
max_num_imprecise_acc = 0
1284 1283

1285 1284
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
1286 1285

0 commit comments

Comments
 (0)