diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py
index 3c85ae8820df..cfb3be5b6d80 100644
--- a/python/triton/compiler/backends/cuda.py
+++ b/python/triton/compiler/backends/cuda.py
@@ -78,7 +78,7 @@ def __init__(self, device_type: tuple) -> None:
     def parse_options(self, opts) -> Any:
         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
         args["allow_fp8e4nv"] = self.capability >= 89
-        args["max_num_imprecise_acc_default"] = 0 if self.capability >= 89 else None
+        args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
         return CUDAOptions(**args)

     @staticmethod
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index c1ee1036ba6f..8caf75fdc2c4 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1263,11 +1263,10 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
     assert acc.type == ret_ty

     # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
-    max_num_imprecise_acc = 0
-    if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
+    if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc is None:
         max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
-    if max_num_imprecise_acc is None:
-        max_num_imprecise_acc = 2**30
+    else:
+        max_num_imprecise_acc = 0
     return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)
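
The net effect of the two hunks, as a minimal standalone sketch. The function name and the boolean/user_value parameters below are hypothetical stand-ins for the real builder.options and dtype plumbing in the diff; this is not code from the patch itself:

    def resolved_max_num_imprecise_acc(capability, lhs_is_fp8, rhs_is_fp8, user_value=None):
        # cuda.py hunk: the per-device default becomes 2**30 on sm_90 and 0 on
        # every other capability (previously: 0 on sm_89+, None otherwise).
        default = 2**30 if capability == 90 else 0
        # semantic.py hunk: the default is consulted only for fp8 x fp8 dots
        # lowered without an explicit max_num_imprecise_acc; every other case
        # now lowers with 0.
        if lhs_is_fp8 and rhs_is_fp8 and user_value is None:
            return default
        return 0

So an fp8 x fp8 tl.dot on sm_90 with no explicit max_num_imprecise_acc now lowers with 2**30 (effectively no limit on imprecise accumulation) instead of 0, while fp8 dots on capabilities below 89 no longer reach the old None -> 2**30 fallback and lower with 0 instead.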