Adding tl.const annotation to mark and validate that const tensors are not being stored to #3360

Merged: 11 commits, Mar 14, 2024
61 changes: 61 additions & 0 deletions python/test/unit/language/test_core.py
@@ -3193,6 +3193,67 @@ def kernel(out_ptr):
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None


@triton.jit
def pass_const(a, b, choose_b):
if choose_b:
return b
else:
return a


@pytest.mark.parametrize("choose_const", [True, False])
@pytest.mark.parametrize("constexpr", [True, False])
@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"])
def test_const(device, choose_const, constexpr, mode):

@triton.jit(do_not_specialize=["choose_const"])
def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elems
val = tl.load(in_ptr + offsets, mask=mask)
LOSE_TAIL
tl.store(final_out + offsets, val, mask=mask)

@triton.jit
def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32,
BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elems
val = tl.load(in_ptr + offsets, mask=mask)
LOSE_TAIL
tl.store(final_out + offsets, val, mask=mask)

if mode == "direct":
if choose_const:
LOSE_TAIL = "final_out = c_out"
else:
LOSE_TAIL = "final_out = out"
elif mode == "call":
LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)"
elif mode == "ternary":
LOSE_TAIL = "final_out = c_out if choose_const else out"
elif mode == "if":
LOSE_TAIL = """
if choose_const:
final_out = c_out
else:
final_out = out
"""

SIZE = 128
input = torch.randn((SIZE, ), dtype=torch.float32, device=device)
output = torch.zeros((SIZE, ), dtype=torch.float32, device=device)
patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''})

expect_fail = (not constexpr and mode != "direct") or choose_const
if expect_fail:
with pytest.raises(triton.CompilationError) as exc_info:
patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
else:
patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
assert torch.all(input == output)


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str, device):
11 changes: 10 additions & 1 deletion python/triton/language/__init__.py
@@ -45,6 +45,8 @@
    broadcast_to,
    cat,
    clamp,
    const,
    const_pointer_type,
    constexpr,
    debug_barrier,
    device_assert,
@@ -142,6 +144,8 @@
    "cat",
    "cdiv",
    "clamp",
    "const",
    "const_pointer_type",
    "constexpr",
    "cos",
    "cumprod",
@@ -245,7 +249,12 @@

def str_to_ty(name):
    if name[0] == "*":
-        ty = str_to_ty(name[1:])
+        name = name[1:]
+        if name[0] == "k":
+            name = name[1:]
+            ty = str_to_ty(name)
+            return const_pointer_type(ty)
+        ty = str_to_ty(name)
        return pointer_type(ty)
    tys = {
        "fp8e4nv": float8e4nv,
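Not part of the diff, but a quick illustration of the "k" marker handled above, calling the internal helper directly (shown only for demonstration; str_to_ty is not public API):

import triton.language as tl

# A "k" right after "*" marks a const pointer; str_to_ty strips it and builds a
# const_pointer_type instead of a plain pointer_type.
ty = tl.str_to_ty("*kfp32")
print(ty)                                # const_pointer<fp32>
print(ty.is_const())                     # True
print(tl.str_to_ty("*fp32").is_const())  # False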
32 changes: 32 additions & 0 deletions python/triton/language/core.py
@@ -307,6 +307,10 @@ def is_block():
    def is_ptr():
        return False

    @staticmethod
    def is_const():
        return False

    def __eq__(self, other: dtype):
        if not isinstance(other, dtype):
            return False
@@ -419,6 +423,23 @@ def scalar(self):
        return self


class const_pointer_type(pointer_type):

    def __init__(self, element_ty: dtype, address_space: int = 1):
        super().__init__(element_ty, address_space)

    def __str__(self):
        return f'const_pointer<{self.element_ty}>'

    def is_const(self):
        return True

    def __eq__(self, other) -> bool:
        if not isinstance(other, const_pointer_type):
            return False
        return self.element_ty == other.element_ty and self.address_space == other.address_space


class block_type(dtype):

    def __init__(self, element_ty: dtype, shape: List):
@@ -514,6 +535,17 @@ def to_ir(self, builder: ir.builder):
# -----------------------


class const:
    """
    This class is used as a type annotation to mark pointers to tensors
    that cannot be modified. Store cannot be called with a const pointer.
    Const pointers are represented by `const_pointer_type`, and the usual
    Triton type-consistency rules apply: for example, you cannot return a
    const pointer from a function that also returns a non-const pointer.
    """
    pass


class constexpr:
    """
    This class is used to store a value that is known at compile-time.
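For context, a minimal usage sketch of the new annotation, based on the docstring and the test above (not part of the diff; the kernel and tensor names are illustrative, and a CUDA device is assumed):

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr: tl.const, dst_ptr, n_elems, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elems
    val = tl.load(src_ptr + offsets, mask=mask)      # loads through a const pointer are fine
    tl.store(dst_ptr + offsets, val, mask=mask)      # dst_ptr is not const
    # tl.store(src_ptr + offsets, val, mask=mask)    # rejected: "Cannot store to a constant pointer"


x = torch.randn(128, device="cuda")
y = torch.empty_like(x)
copy_kernel[(1, )](x, y, x.numel(), BLOCK_SIZE=128)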
5 changes: 5 additions & 0 deletions python/triton/language/semantic.py
@@ -1123,6 +1123,9 @@ def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_ch
    cache = _str_to_store_cache_modifier(cache_modifier)
    eviction = _str_to_eviction_policy(eviction_policy)

    if ptr.type.is_const() or ptr.type.scalar.is_const():
        raise ValueError("Cannot store to a constant pointer")

    if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
        # Store by a block pointer: `pointer_type<block_type<>>`
        return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
@@ -1149,6 +1152,8 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
                                builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
    if ptr.type.is_const() or ptr.type.element_ty.is_const():
        raise ValueError("Cannot store to a constant pointer")
    element_ty = ptr.type.scalar.element_ty
    if element_ty is tl.float16 and op != 'add':
        raise ValueError("atomic_" + op + " does not support fp16")
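A hedged sketch of what the second check catches (not from this PR; the kernel and names are illustrative): atomics are type-checked through atom_red_typechecking_impl, so a read-modify-write through a const pointer is rejected the same way as tl.store:

import torch
import triton
import triton.language as tl


@triton.jit
def bump(counters: tl.const, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    # atomic_add writes through `counters`, so compilation fails with
    # "Cannot store to a constant pointer".
    tl.atomic_add(counters + offsets, 1)


c = torch.zeros(64, dtype=torch.int32, device="cuda")
# bump[(1, )](c, BLOCK_SIZE=64)  # raises triton.CompilationError at launch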
30 changes: 18 additions & 12 deletions python/triton/runtime/jit.py
@@ -117,6 +117,10 @@ def annotation(self):
    def is_constexpr(self):
        return "constexpr" in self.annotation

    @cached_property
    def is_const(self):
        return "const" in self.annotation and not self.is_constexpr

    @property
    def default(self):
        return self._param.default
@@ -140,17 +144,22 @@ def __init__(self, value, param):
    def name(self):
        return self.param.name

-    def signature_key(self):
+    def mangled_type(self):
        annotation = self.param.annotation
-        if "Tensor" in annotation:
-            return self.value.dtype
+        const_str = "const " if self.param.is_const else ""
+        is_pointer = False
        for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
            width = annotation[annotation.find(ty1) + len(ty1):]
            if width and ty1 in annotation:
                return f"{ty2}{width}"
        if annotation == "bool":
            return "u1"
-        return JITFunction._key_of(self.value)
+
+        if "Tensor" in annotation:
+            key = self.value.dtype
+        else:
+            key = JITFunction._key_of(self.value)
+        return JITFunction._type_of(key, self.param.is_const)

    def specialization_key(self):
        assert not self.param.do_not_specialize
@@ -255,7 +264,7 @@ def is_divisible_by_16(x):
    # equal_to_1)

    @staticmethod
-    def _type_of(key):
+    def _type_of(key, is_const=False):
        # `None` is nullptr. Implicitly convert to *i8.
        if key is None:
            return "*i8"
@@ -288,7 +297,8 @@ def _type_of(key):
        # reinterpret can create triton type
        for v in list(tys.values()):
            tys[v] = v
-        return key if isinstance(key, str) else f"*{tys[dtype_str]}"
+        const_str = "k" if is_const else ""
+        return key if isinstance(key, str) else f"*{const_str}{tys[dtype_str]}"

    def _make_constants(self, constexpr_key):
        constants = dict(zip(self.constexprs, constexpr_key))
@@ -387,7 +397,7 @@ def run(self, *args, grid, warmup, **kwargs):
        grid_2 = grid[2] if grid_size > 2 else 1
        # compute cache key
        args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
-        sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
+        sig_key = tuple(arg.mangled_type() for arg in args if not arg.param.is_constexpr)
        spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
        constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
        key = (sig_key, constexpr_key, spec_key, options)
@@ -405,11 +415,7 @@ def run(self, *args, grid, warmup, **kwargs):
                raise TypeError(f"Callable constexpr at index {i} is not supported")

        # Build kernel signature -- doesn't include constexpr arguments.
-        signature = {
-            arg.param.num: self._type_of(arg.signature_key())
-            for arg in args
-            if not arg.param.is_constexpr
-        }
+        signature = {arg.param.num: arg.mangled_type() for arg in args if not arg.param.is_constexpr}

        if self._call_hook(key, signature, device, constants, options, configs):
            return None
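Finally, a small sketch of what the mangling side produces (again calling a private helper directly, only to illustrate the change): a const-qualified pointer argument is keyed as "*kfp32" instead of "*fp32", which str_to_ty in python/triton/language/__init__.py decodes back into a const_pointer_type.

import torch
from triton.runtime.jit import JITFunction

print(JITFunction._type_of(torch.float32))                 # *fp32
print(JITFunction._type_of(torch.float32, is_const=True))  # *kfp32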