diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 33cf0a8054f0..fe5f24f74781 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1403,6 +1403,8 @@ def dot(input, other, acc=None, f32_backend=None, max_num_imprecise_acc=None, ou
     :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     :param other: The second tensor to be multiplied.
     :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
+    :param f32_backend: The backend to use for fp32 x fp32 matrix multiplication.
+    :type f32_backend: string. Available options for NVIDIA: :code:`"tf32"`, :code:`"3xtf32"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for AMD: :code:`"ieee"`.
     """
     f32_backend = _constexpr_to_value(f32_backend)
     out_dtype = _constexpr_to_value(out_dtype)
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index d55bcb071417..c1e5747c555d 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1288,7 +1288,7 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope
 # ===----------------------------------------------------------------------===//
 
 
-def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, f32_backend: str, max_num_imprecise_acc: int,
+def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, f32_backend: Optional[str], max_num_imprecise_acc: int,
         out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
 
     def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
@@ -1323,6 +1323,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
 
     assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)
 
+    if f32_backend is None:
+        f32_backend = builder.options.default_f32_backend
+    else:
+        assert f32_backend in builder.options.allowed_f32_backends, f"f32_backend must be one of {builder.options.allowed_f32_backends}. Got {f32_backend}"
+
     lhs_rank = len(lhs.shape)
     rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 5b2da68cf8ae..ad2a2117fef5 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -1,7 +1,7 @@
 from triton.backends.compiler import BaseBackend
 from triton._C.libtriton import ir, passes, llvm, amd
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Tuple
 import hashlib
 import tempfile
 import os
@@ -21,6 +21,8 @@ class HIPOptions:
     debug: bool = False
     arch: str = None
     allow_fp8e4nv: bool = False
+    default_f32_backend: str = "ieee"
+    allowed_f32_backends: Tuple[str] = ("ieee",)
     enable_fp_fusion: bool = True
     capability: int = None
     # TODO:
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 65bd65eba411..513ea75dac0e 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -4,7 +4,7 @@
 
 from dataclasses import dataclass
 import functools
-from typing import Any
+from typing import Any, Tuple
 import hashlib
 import re
 import tempfile
@@ -62,6 +62,8 @@ class CUDAOptions:
     ptx_version: int = None
     enable_fp_fusion: bool = True
     allow_fp8e4nv: bool = False
+    default_f32_backend: str = "tf32"
+    allowed_f32_backends: Tuple[str] = ("tf32", "3xtf32", "ieee")
     max_num_imprecise_acc_default: bool = None
     extern_libs: dict = None
     debug: bool = False
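
Usage sketch (not part of the patch, and assuming it has been applied): an fp32 matmul kernel that forwards the new `f32_backend` keyword to `tl.dot`. The kernel name `matmul_f32_kernel`, the block sizes, and the assumption that M, N, K are multiples of the block sizes are illustrative only.

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_f32_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                      F32_BACKEND: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Row-major loads; no masks, so M, N, K must divide evenly by the block sizes.
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(b_ptr + (k + offs_k)[:, None] * N + offs_n[None, :])
        # Per-call override of the fp32 x fp32 strategy; omitting the keyword (or
        # passing None) falls back to the backend's default_f32_backend
        # ("tf32" on NVIDIA, "ieee" on AMD).
        acc += tl.dot(a, b, f32_backend=F32_BACKEND)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


def matmul_f32(a, b, f32_backend="3xtf32"):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    matmul_f32_kernel[grid](a, b, c, M, N, K,
                            BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
                            F32_BACKEND=f32_backend)
    return c

The per-dot keyword only narrows the choice within builder.options.allowed_f32_backends; the backend dataclasses above remain the single source of truth for which modes a target supports and which one is the default.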