Commit d42ca11

Adding tl.const annotation to mark and validate that const tensors are not being stored to (#3360)

Introduces the `tl.const` argument annotation for marking pointers to
constant memory. Tensors accessed through such pointers cannot be
modified (you cannot call `store` with a const pointer).
pawelszczerbuk authored Mar 14, 2024
1 parent 0db57b3 commit d42ca11
Showing 5 changed files with 126 additions and 13 deletions.
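
A minimal usage sketch of the new annotation (hypothetical kernel, not part of this commit): the source pointer is marked tl.const, so loads through it compile, while a store through it would be rejected.

import triton
import triton.language as tl


@triton.jit
def scale_kernel(src_ptr: tl.const,  # read-only pointer: loads only
                 dst_ptr,            # writable pointer
                 scale, n_elems, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elems
    val = tl.load(src_ptr + offsets, mask=mask)           # OK: reading through a const pointer
    tl.store(dst_ptr + offsets, val * scale, mask=mask)   # OK: dst_ptr is not const
    # tl.store(src_ptr + offsets, val, mask=mask)         # would fail: "Cannot store to a constant pointer"

Call sites are unchanged; the annotation only affects type checking and the signature cache key (see the jit.py changes below).
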
61 changes: 61 additions & 0 deletions python/test/unit/language/test_core.py
@@ -3232,6 +3232,67 @@ def kernel(out_ptr):
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None


@triton.jit
def pass_const(a, b, choose_b):
if choose_b:
return b
else:
return a


@pytest.mark.parametrize("choose_const", [True, False])
@pytest.mark.parametrize("constexpr", [True, False])
@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"])
def test_const(device, choose_const, constexpr, mode):

@triton.jit(do_not_specialize=["choose_const"])
def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elems
val = tl.load(in_ptr + offsets, mask=mask)
LOSE_TAIL
tl.store(final_out + offsets, val, mask=mask)

@triton.jit
def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32,
BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elems
val = tl.load(in_ptr + offsets, mask=mask)
LOSE_TAIL
tl.store(final_out + offsets, val, mask=mask)

if mode == "direct":
if choose_const:
LOSE_TAIL = "final_out = c_out"
else:
LOSE_TAIL = "final_out = out"
elif mode == "call":
LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)"
elif mode == "ternary":
LOSE_TAIL = "final_out = c_out if choose_const else out"
elif mode == "if":
LOSE_TAIL = """
if choose_const:
final_out = c_out
else:
final_out = out
"""

SIZE = 128
input = torch.randn((SIZE, ), dtype=torch.float32, device=device)
output = torch.zeros((SIZE, ), dtype=torch.float32, device=device)
patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''})

expect_fail = (not constexpr and mode != "direct") or choose_const
if expect_fail:
with pytest.raises(triton.CompilationError) as exc_info:
patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
else:
patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
assert torch.all(input == output)


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str, device):
11 changes: 10 additions & 1 deletion python/triton/language/__init__.py
@@ -45,6 +45,8 @@
broadcast_to,
cat,
clamp,
const,
const_pointer_type,
constexpr,
debug_barrier,
device_assert,
@@ -142,6 +144,8 @@
"cat",
"cdiv",
"clamp",
"const",
"const_pointer_type",
"constexpr",
"cos",
"cumprod",
@@ -245,7 +249,12 @@

def str_to_ty(name):
if name[0] == "*":
ty = str_to_ty(name[1:])
name = name[1:]
if name[0] == "k":
name = name[1:]
ty = str_to_ty(name)
return const_pointer_type(ty)
ty = str_to_ty(name)
return pointer_type(ty)
tys = {
"fp8e4nv": float8e4nv,
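
Illustrative calls against the parser change above (a sketch; str_to_ty is the module-level helper shown in this hunk): a leading * still denotes a pointer, and a k immediately after it marks the pointee as const.

from triton.language import str_to_ty, pointer_type, const_pointer_type, float32

assert str_to_ty("fp32") == float32                        # scalar type, unchanged
assert str_to_ty("*fp32") == pointer_type(float32)         # mutable pointer, unchanged
assert str_to_ty("*kfp32") == const_pointer_type(float32)  # new: "k" selects the const pointer type
assert str_to_ty("*kfp32").is_const()
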
32 changes: 32 additions & 0 deletions python/triton/language/core.py
@@ -307,6 +307,10 @@ def is_block():
def is_ptr():
return False

@staticmethod
def is_const():
return False

def __eq__(self, other: dtype):
if not isinstance(other, dtype):
return False
@@ -419,6 +423,23 @@ def scalar(self):
return self


class const_pointer_type(pointer_type):

def __init__(self, element_ty: dtype, address_space: int = 1):
super().__init__(element_ty, address_space)

def __str__(self):
return f'const_pointer<{self.element_ty}>'

def is_const(self):
return True

def __eq__(self, other) -> bool:
if not isinstance(other, const_pointer_type):
return False
return self.element_ty == other.element_ty and self.address_space == other.address_space


class block_type(dtype):

def __init__(self, element_ty: dtype, shape: List):
@@ -514,6 +535,17 @@ def to_ir(self, builder: ir.builder):
# -----------------------


class const:
"""
    This class is used as a type annotation to mark pointers to constant data.
    The `store` function cannot be called with a pointer to const data. Constness
    is part of the pointer type, and the usual Triton type-consistency rules
    apply: for example, a function cannot return a constant pointer from one
    return statement and a non-constant pointer from another.
"""
pass


class constexpr:
"""
This class is used to store a value that is known at compile-time.
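
A small sketch of how the new type behaves from Python (using the const_pointer_type added above together with the existing tl.pointer_type; this commit re-exports it as tl.const_pointer_type):

import triton.language as tl

cptr = tl.const_pointer_type(tl.float32)
ptr = tl.pointer_type(tl.float32)

print(cptr)             # const_pointer<fp32>
print(cptr.is_const())  # True
print(ptr.is_const())   # False
print(cptr == ptr)      # False: constness is part of the type identity
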
5 changes: 5 additions & 0 deletions python/triton/language/semantic.py
@@ -1121,6 +1121,9 @@ def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_ch
cache = _str_to_store_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)

if ptr.type.is_const() or ptr.type.scalar.is_const():
raise ValueError("Cannot store to a constant pointer")

if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
# Store by a block pointer: `pointer_type<block_type<>>`
return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
@@ -1147,6 +1150,8 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
if ptr.type.is_const() or ptr.type.element_ty.is_const():
raise ValueError("Cannot store to a constant pointer")
element_ty = ptr.type.scalar.element_ty
if element_ty is tl.float16 and op != 'add':
raise ValueError("atomic_" + op + " does not support fp16")
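
With the guards above, writing through a const pointer is caught when the kernel is compiled; a sketch of what that looks like from user code (hypothetical kernel, assuming a CUDA device is available):

import pytest
import torch
import triton
import triton.language as tl


@triton.jit
def bad_increment(ptr: tl.const, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    val = tl.load(ptr + offsets)      # reading is still fine
    tl.store(ptr + offsets, val + 1)  # trips "Cannot store to a constant pointer"


x = torch.zeros(128, device="cuda")
with pytest.raises(triton.CompilationError):
    bad_increment[(1, )](x, BLOCK_SIZE=128)
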
30 changes: 18 additions & 12 deletions python/triton/runtime/jit.py
@@ -117,6 +117,10 @@ def annotation(self):
def is_constexpr(self):
return "constexpr" in self.annotation

@cached_property
def is_const(self):
return "const" in self.annotation and not self.is_constexpr

@property
def default(self):
return self._param.default
@@ -140,17 +144,22 @@ def __init__(self, value, param):
def name(self):
return self.param.name

def signature_key(self):
def mangled_type(self):
annotation = self.param.annotation
if "Tensor" in annotation:
return self.value.dtype
const_str = "const " if self.param.is_const else ""
is_pointer = False
for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
width = annotation[annotation.find(ty1) + len(ty1):]
if width and ty1 in annotation:
return f"{ty2}{width}"
if annotation == "bool":
return "u1"
return JITFunction._key_of(self.value)

if "Tensor" in annotation:
key = self.value.dtype
else:
key = JITFunction._key_of(self.value)
return JITFunction._type_of(key, self.param.is_const)

def specialization_key(self):
assert not self.param.do_not_specialize
@@ -255,7 +264,7 @@ def is_divisible_by_16(x):
# equal_to_1)

@staticmethod
def _type_of(key):
def _type_of(key, is_const=False):
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return "*i8"
@@ -288,7 +297,8 @@ def _type_of(key):
# reinterpret can create triton type
for v in list(tys.values()):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
const_str = "k" if is_const else ""
return key if isinstance(key, str) else f"*{const_str}{tys[dtype_str]}"

def _make_constants(self, constexpr_key):
constants = dict(zip(self.constexprs, constexpr_key))
@@ -387,7 +397,7 @@ def run(self, *args, grid, warmup, **kwargs):
grid_2 = grid[2] if grid_size > 2 else 1
# compute cache key
args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
sig_key = tuple(arg.mangled_type() for arg in args if not arg.param.is_constexpr)
spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
key = (sig_key, constexpr_key, spec_key, options)
@@ -405,11 +415,7 @@ def run(self, *args, grid, warmup, **kwargs):
raise TypeError(f"Callable constexpr at index {i} is not supported")

# Build kernel signature -- doesn't include constexpr arguments.
signature = {
arg.param.num: self._type_of(arg.signature_key())
for arg in args
if not arg.param.is_constexpr
}
signature = {arg.param.num: arg.mangled_type() for arg in args if not arg.param.is_constexpr}

if self._call_hook(key, signature, device, constants, options, configs):
return None
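
A sketch of the resulting signature mangling (poking at the static helper directly; the dtype-to-string table inside _type_of is not shown in this diff, so the exact output strings are an assumption): a const pointer argument gets a k between the * and the element type, so const and non-const specializations are keyed and cached separately.

import torch
from triton.runtime.jit import JITFunction

print(JITFunction._type_of(torch.float32))                 # "*fp32"
print(JITFunction._type_of(torch.float32, is_const=True))  # "*kfp32"

On the language side, str_to_ty("*kfp32") parses this back into const_pointer_type(float32), which is what the store check in semantic.py keys off.
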
