diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index b5c34cf0f04b..22095c66c67b 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3232,6 +3232,67 @@ def kernel(out_ptr):
     assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
 
 
+@triton.jit
+def pass_const(a, b, choose_b):
+    if choose_b:
+        return b
+    else:
+        return a
+
+
+@pytest.mark.parametrize("choose_const", [True, False])
+@pytest.mark.parametrize("constexpr", [True, False])
+@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"])
+def test_const(device, choose_const, constexpr, mode):
+
+    @triton.jit(do_not_specialize=["choose_const"])
+    def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elems
+        val = tl.load(in_ptr + offsets, mask=mask)
+        LOSE_TAIL
+        tl.store(final_out + offsets, val, mask=mask)
+
+    @triton.jit
+    def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32,
+                         BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elems
+        val = tl.load(in_ptr + offsets, mask=mask)
+        LOSE_TAIL
+        tl.store(final_out + offsets, val, mask=mask)
+
+    if mode == "direct":
+        if choose_const:
+            LOSE_TAIL = "final_out = c_out"
+        else:
+            LOSE_TAIL = "final_out = out"
+    elif mode == "call":
+        LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)"
+    elif mode == "ternary":
+        LOSE_TAIL = "final_out = c_out if choose_const else out"
+    elif mode == "if":
+        LOSE_TAIL = """
+    if choose_const:
+        final_out = c_out
+    else:
+        final_out = out
+"""
+
+    SIZE = 128
+    input = torch.randn((SIZE, ), dtype=torch.float32, device=device)
+    output = torch.zeros((SIZE, ), dtype=torch.float32, device=device)
+    patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''})
+
+    expect_fail = (not constexpr and mode != "direct") or choose_const
+    if expect_fail:
+        with pytest.raises(triton.CompilationError) as exc_info:
+            patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
+    else:
+        patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE)
+        assert torch.all(input == output)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
 def test_dot_without_load(dtype_str, device):
diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py
index b439eb2f9cc9..e0ae6d97aba6 100644
--- a/python/triton/language/__init__.py
+++ b/python/triton/language/__init__.py
@@ -45,6 +45,8 @@
     broadcast_to,
     cat,
     clamp,
+    const,
+    const_pointer_type,
     constexpr,
     debug_barrier,
     device_assert,
@@ -142,6 +144,8 @@
     "cat",
     "cdiv",
     "clamp",
+    "const",
+    "const_pointer_type",
     "constexpr",
     "cos",
     "cumprod",
@@ -245,7 +249,12 @@ def str_to_ty(name):
     if name[0] == "*":
-        ty = str_to_ty(name[1:])
+        name = name[1:]
+        if name[0] == "k":
+            name = name[1:]
+            ty = str_to_ty(name)
+            return const_pointer_type(ty)
+        ty = str_to_ty(name)
         return pointer_type(ty)
     tys = {
         "fp8e4nv": float8e4nv,
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 66a6778f9f3f..db749aa13933 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -307,6 +307,10 @@ def is_block():
     def is_ptr():
         return False
 
+    @staticmethod
+    def is_const():
+        return False
+
     def __eq__(self, other: dtype):
         if not isinstance(other, dtype):
             return False
@@ -419,6 +423,23 @@ def scalar(self):
         return self
 
 
+class const_pointer_type(pointer_type):
+
+    def __init__(self, element_ty: dtype, address_space: int = 1):
+        super().__init__(element_ty, address_space)
+
+    def __str__(self):
+        return f'const_pointer<{self.element_ty}>'
+
+    def is_const(self):
+        return True
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, const_pointer_type):
+            return False
+        return self.element_ty == other.element_ty and self.address_space == other.address_space
+
+
 class block_type(dtype):
 
     def __init__(self, element_ty: dtype, shape: List):
@@ -514,6 +535,17 @@ def to_ir(self, builder: ir.builder):
 # -----------------------
 
 
+class const:
+    """
+    This class is used as a type annotation to mark pointers to constant data.
+    The `store` function cannot be called with a pointer to const. Constness
+    is part of the pointer type and the usual Triton type consistency rules
+    apply. For example, you cannot have a function that returns a constant
+    pointer in one return statement and a non-constant pointer in another.
+    """
+    pass
+
+
 class constexpr:
     """
     This class is used to store a value that is known at compile-time.
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 5480ed9843fd..ac8e5af53201 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1121,6 +1121,9 @@ def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_ch
     cache = _str_to_store_cache_modifier(cache_modifier)
     eviction = _str_to_eviction_policy(eviction_policy)
 
+    if ptr.type.is_const() or ptr.type.scalar.is_const():
+        raise ValueError("Cannot store to a constant pointer")
+
     if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
         # Store by a block pointer: `pointer_type<block_type<>>`
         return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
@@ -1147,6 +1150,8 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
                                builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
     if not ptr.type.scalar.is_ptr():
         raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
+    if ptr.type.is_const() or ptr.type.element_ty.is_const():
+        raise ValueError("Cannot store to a constant pointer")
     element_ty = ptr.type.scalar.element_ty
     if element_ty is tl.float16 and op != 'add':
         raise ValueError("atomic_" + op + " does not support fp16")
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 1d21501b6e83..82fa19ba3506 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -117,6 +117,10 @@ def annotation(self):
     def is_constexpr(self):
         return "constexpr" in self.annotation
 
+    @cached_property
+    def is_const(self):
+        return "const" in self.annotation and not self.is_constexpr
+
     @property
     def default(self):
         return self._param.default
@@ -140,17 +144,22 @@ def __init__(self, value, param):
     def name(self):
         return self.param.name
 
-    def signature_key(self):
+    def mangled_type(self):
         annotation = self.param.annotation
-        if "Tensor" in annotation:
-            return self.value.dtype
+        const_str = "const " if self.param.is_const else ""
+        is_pointer = False
         for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
             width = annotation[annotation.find(ty1) + len(ty1):]
             if width and ty1 in annotation:
                 return f"{ty2}{width}"
         if annotation == "bool":
             return "u1"
-        return JITFunction._key_of(self.value)
+
+        if "Tensor" in annotation:
+            key = self.value.dtype
+        else:
+            key = JITFunction._key_of(self.value)
+        return JITFunction._type_of(key, self.param.is_const)
 
     def specialization_key(self):
         assert not self.param.do_not_specialize
@@ -255,7 +264,7 @@ def is_divisible_by_16(x):
             # equal_to_1)
 
     @staticmethod
-    def _type_of(key):
+    def _type_of(key, is_const=False):
         # `None` is nullptr. Implicitly convert to *i8.
         if key is None:
             return "*i8"
@@ -288,7 +297,8 @@ def _type_of(key):
         # reinterpret can create triton type
         for v in list(tys.values()):
             tys[v] = v
-        return key if isinstance(key, str) else f"*{tys[dtype_str]}"
+        const_str = "k" if is_const else ""
+        return key if isinstance(key, str) else f"*{const_str}{tys[dtype_str]}"
 
     def _make_constants(self, constexpr_key):
         constants = dict(zip(self.constexprs, constexpr_key))
@@ -387,7 +397,7 @@ def run(self, *args, grid, warmup, **kwargs):
         grid_2 = grid[2] if grid_size > 2 else 1
         # compute cache key
         args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
-        sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
+        sig_key = tuple(arg.mangled_type() for arg in args if not arg.param.is_constexpr)
         spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
         constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
         key = (sig_key, constexpr_key, spec_key, options)
@@ -405,11 +415,7 @@ def run(self, *args, grid, warmup, **kwargs):
                     raise TypeError(f"Callable constexpr at index {i} is not supported")
 
             # Build kernel signature -- doesn't include constexpr arguments.
-            signature = {
-                arg.param.num: self._type_of(arg.signature_key())
-                for arg in args
-                if not arg.param.is_constexpr
-            }
+            signature = {arg.param.num: arg.mangled_type() for arg in args if not arg.param.is_constexpr}
             if self._call_hook(key, signature, device, constants, options, configs):
                 return None
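For reference, a minimal usage sketch of the new `tl.const` annotation, modeled on the test added above. The kernel and tensor names are illustrative (not part of the change), and a CUDA device is assumed:

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src: tl.const, dst, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr):
    # Reading through a const-qualified pointer is allowed.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elems
    val = tl.load(src + offsets, mask=mask)
    # Storing through `src` instead of `dst` would be rejected at compile time,
    # per the new check in semantic.store().
    tl.store(dst + offsets, val, mask=mask)


x = torch.randn(128, device="cuda")
y = torch.empty_like(x)
copy_kernel[(1, )](x, y, x.numel(), BLOCK_SIZE=128)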