Add frontend options
lezcano committed Mar 5, 2024
1 parent ca418ff commit 17ea6d4
Showing 4 changed files with 14 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/triton/language/core.py
@@ -1403,6 +1403,8 @@ def dot(input, other, acc=None, f32_backend=None, max_num_imprecise_acc=None, ou
     :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     :param other: The second tensor to be multiplied.
     :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
+    :param f32_backend: The backend to use for fp32 x fp32.
+    :type f32_backend: string. Available options for nvidia: :code:`"tf32"`, :code:`"3xtf32"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`.
     """
     f32_backend = _constexpr_to_value(f32_backend)
     out_dtype = _constexpr_to_value(out_dtype)
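From a kernel author's point of view, the new argument threads straight through tl.dot. Below is a minimal sketch of an fp32 matmul kernel opting into 3xtf32 on nvidia; the kernel body, the pointer arithmetic, and the assumption that M, N, K are multiples of the block sizes are illustrative, not part of this commit.

import triton
import triton.language as tl


@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        # Ask for the three-pass tf32 scheme for this fp32 x fp32 product;
        # leaving f32_backend unset falls back to the backend default
        # ("tf32" on nvidia, "ieee" on amd).
        acc += tl.dot(a, b, f32_backend="3xtf32")
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)

Since f32_backend defaults to None, existing kernels compile unchanged and simply pick up whatever the active backend declares as its default.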
7 changes: 6 additions & 1 deletion python/triton/language/semantic.py
@@ -1288,7 +1288,7 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope
 # ===----------------------------------------------------------------------===//


-def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, f32_backend: str, max_num_imprecise_acc: int,
+def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, f32_backend: Optional[str], max_num_imprecise_acc: int,
         out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:

     def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
@@ -1323,6 +1323,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):

     assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)

+    if f32_backend is None:
+        f32_backend = builder.options.default_f32_backend
+    else:
+        assert f32_backend in builder.options.allowed_f32_backends, f"f32_backend must be one of {builder.options.allowed_f32_backends}. Got {f32_backend}"
+
     lhs_rank = len(lhs.shape)
     rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
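The added branch is the entire frontend contract for the new option: None defers to the active backend's default, and an explicit choice must appear in that backend's allowlist. Here is a standalone sketch of the pattern, where _Options and resolve_f32_backend are hypothetical stand-ins for builder.options and the inlined check above.

from dataclasses import dataclass
from typing import Optional, Tuple


# Hypothetical stand-in for the per-backend options object on the builder.
@dataclass
class _Options:
    default_f32_backend: str = "tf32"
    allowed_f32_backends: Tuple[str, ...] = ("tf32", "3xtf32", "ieee")


def resolve_f32_backend(f32_backend: Optional[str], options: _Options) -> str:
    if f32_backend is None:
        # No explicit choice: use whatever this backend declares as default.
        return options.default_f32_backend
    # An explicit choice must be one this backend can actually lower.
    assert f32_backend in options.allowed_f32_backends, \
        f"f32_backend must be one of {options.allowed_f32_backends}. Got {f32_backend}"
    return f32_backend


assert resolve_f32_backend(None, _Options()) == "tf32"
assert resolve_f32_backend("3xtf32", _Options()) == "3xtf32"
# With AMD-shaped options, an explicit "tf32" now fails fast at compile time
# instead of silently lowering to something else.
amd_opts = _Options(default_f32_backend="ieee", allowed_f32_backends=("ieee",))
assert resolve_f32_backend(None, amd_opts) == "ieee"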
4 changes: 3 additions & 1 deletion third_party/amd/backend/compiler.py
@@ -1,7 +1,7 @@
 from triton.backends.compiler import BaseBackend
 from triton._C.libtriton import ir, passes, llvm, amd
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Tuple
 import hashlib
 import tempfile
 import os
@@ -21,6 +21,8 @@ class HIPOptions:
     debug: bool = False
     arch: str = None
     allow_fp8e4nv: bool = False
+    default_f32_backend: str = "ieee"
+    allowed_f32_backends: Tuple[str, ...] = ("ieee",)
     enable_fp_fusion: bool = True
     capability: int = None
     # TODO:
4 changes: 3 additions & 1 deletion third_party/nvidia/backend/compiler.py
@@ -4,7 +4,7 @@

 from dataclasses import dataclass
 import functools
-from typing import Any
+from typing import Any, Tuple
 import hashlib
 import re
 import tempfile
@@ -62,6 +62,8 @@ class CUDAOptions:
     ptx_version: int = None
     enable_fp_fusion: bool = True
     allow_fp8e4nv: bool = False
+    default_f32_backend: str = "tf32"
+    allowed_f32_backends: Tuple[str, ...] = ("tf32", "3xtf32", "ieee")
     max_num_imprecise_acc_default: bool = None
     extern_libs: dict = None
     debug: bool = False
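For intuition on the three nvidia options: "ieee" keeps full fp32 multiplies, "tf32" rounds each operand to tf32 (10 explicit mantissa bits) before the tensor-core product, and "3xtf32" splits each fp32 operand into a tf32 head plus a residual tail, combining three tf32 products to recover most fp32 accuracy. Below is a plain-PyTorch sketch of that split, approximating tf32 rounding by truncating the low 13 mantissa bits; the helper names and the CPU emulation are ours, not Triton's.

import torch


def to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Keep the 10 high mantissa bits of fp32 (tf32 precision) by zeroing the
    # low 13; real hardware rounds rather than truncates, close enough here.
    return (x.view(torch.int32) & ~0x1FFF).view(torch.float32)


def dot_3xtf32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Split each operand into a tf32 head and an fp32 tail, then sum three
    # tf32-input products; the tiny tail x tail term is dropped.
    a_hi, b_hi = to_tf32(a), to_tf32(b)
    a_lo, b_lo = a - a_hi, b - b_hi
    return to_tf32(a_lo) @ b_hi + a_hi @ to_tf32(b_lo) + a_hi @ b_hi


a, b = torch.randn(64, 64), torch.randn(64, 64)
ref = a.double() @ b.double()
err_1x = (to_tf32(a) @ to_tf32(b) - ref).abs().max()
err_3x = (dot_3xtf32(a, b) - ref).abs().max()
assert err_3x < err_1x  # three passes recover most of the lost accuracy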
