diff --git a/python/setup.py b/python/setup.py index 362d97c15059..41b17ee6875c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -249,6 +249,7 @@ def build_extension(self, ext): "triton/_C", "triton/common", "triton/compiler", + "triton/debugger", "triton/language", "triton/language/extra", "triton/ops", diff --git a/python/test/unit/debugger/test_debugger.py b/python/test/unit/debugger/test_debugger.py new file mode 100644 index 000000000000..741fcab3becd --- /dev/null +++ b/python/test/unit/debugger/test_debugger.py @@ -0,0 +1,69 @@ +import random + +import torch + +import triton +import triton.language as tl +from triton.debugger.debugger import program_ids_from_grid + + +def test_addition(): + + @triton.jit(interpret=True) + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + a = torch.rand((128,), device="cuda") + b = torch.rand((128,), device="cuda") + expected = a + b + output = torch.empty((128,), device="cuda") + + def grid(meta): + return (triton.cdiv(128, meta["BLOCK_SIZE"]),) + + add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32) + + assert torch.allclose(expected, output, atol=1e-2, rtol=0) + + +def test_program_ids_from_grid(): + random.seed(123) + grid = (3, 4) + expected_combinations = 3 * 4 + unique_combinations = set(program_ids_from_grid(grid)) + assert len(unique_combinations) == expected_combinations + + first_run = list(program_ids_from_grid(grid)) + second_run = list(program_ids_from_grid(grid)) + assert first_run != second_run + + +def test_atomic(): + @triton.jit(interpret=True) + def atomic( + x_ptr, + ): + pid = tl.program_id(axis=0) + tl.atomic_add(x_ptr + pid, 1) + t = tl.atomic_xchg(x_ptr + pid, 3) + t += 1 # 2 + tl.atomic_cas(x_ptr + pid, 3, t) # match + tl.atomic_cas(x_ptr + pid, 40, 9) # no match + nb_dim = 16 + a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda") + + atomic[(nb_dim, )](a) + assert torch.allclose(a, torch.full_like(a, 2)) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 11981f54c328..6ada156afe30 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1605,68 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): ) -layouts = [ - BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), - BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), - BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]), - BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1]) -] - - -@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]]) -@pytest.mark.parametrize("src_layout", layouts) -def test_reduce_2d(M, N, src_layout, device='cuda'): - ir = f""" - #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ - tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> - %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> - %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> - %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> - %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> - %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> - %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> - %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> - %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> - %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src> - %11 = "tt.reduce"(%10) ({{ - ^bb0(%arg2: i32, %arg3: i32): - %13 = arith.addi %arg2, %arg3 : i32 - tt.reduce.return %13 : i32 - }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %12 = "tt.reduce"(%11) ({{ - ^bb0(%arg2: i32, %arg3: i32): - %13 = arith.addi %arg2, %arg3 : i32 - tt.reduce.return %13 : i32 - }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32 - tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32 - tt.return - }} - }} - """ - import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - - rs = RandomState(17) - x = rs.randint(0, 4, (M, N)).astype('int32') - x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32') - - z = np.zeros((1,)).astype('int32') - - x_tri = torch.tensor(x, device=device) - z_tri = torch.tensor(z, device=device) - - pgm = kernel[(1, 1, 1)](x_tri, z_tri) - - z_ref = np.sum(x) - - np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) +# layouts = [ +# BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), +# BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), +# BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]), +# BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1]) +# ] + + +# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]]) +# @pytest.mark.parametrize("src_layout", layouts) +# def test_reduce_2d(M, N, src_layout, device='cuda'): +# ir = f""" +# #src = {src_layout} +# module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ +# tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ +# %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src> +# %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> +# %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> +# %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> +# %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> +# %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> +# %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> +# %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> +# %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> +# %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> +# %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> +# %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src> +# %11 = "tt.reduce"(%10) ({{ +# ^bb0(%arg2: i32, %arg3: i32): +# %13 = arith.addi %arg2, %arg3 : i32 +# tt.reduce.return %13 : i32 +# }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> +# %12 = "tt.reduce"(%11) ({{ +# ^bb0(%arg2: i32, %arg3: i32): +# %13 = arith.addi %arg2, %arg3 : i32 +# tt.reduce.return %13 : i32 +# }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32 +# tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32 +# tt.return +# }} +# }} +# """ +# import tempfile +# with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: +# f.write(ir) +# f.flush() +# kernel = triton.compile(f.name) +# +# rs = RandomState(17) +# x = rs.randint(0, 4, (M, N)).astype('int32') +# x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32') +# +# z = np.zeros((1,)).astype('int32') +# +# x_tri = torch.tensor(x, device=device) +# z_tri = torch.tensor(z, device=device) +# +# pgm = kernel[(1, 1, 1)](x_tri, z_tri) +# +# z_ref = np.sum(x) +# +# np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) def test_generic_reduction(device='cuda'): diff --git a/python/triton/__init__.py b/python/triton/__init__.py index adbceefedd8a..14c9d61bdcb7 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -18,6 +18,8 @@ ) from .runtime.jit import jit from .compiler import compile, CompilationError +from .debugger.debugger import program_ids_from_grid + from . import language from . import testing @@ -41,6 +43,7 @@ "runtime", "TensorWrapper", "testing", + "program_ids_from_grid", ] diff --git a/python/triton/debugger/__init__.py b/python/triton/debugger/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/triton/debugger/core.py b/python/triton/debugger/core.py new file mode 100644 index 000000000000..82f3f43a25a0 --- /dev/null +++ b/python/triton/debugger/core.py @@ -0,0 +1,9 @@ +from typing import Tuple + +import dataclasses + + +@dataclasses.dataclass +class ExecutionContext: + program_id: Tuple[int] + program_size: Tuple[int] diff --git a/python/triton/debugger/debugger.py b/python/triton/debugger/debugger.py new file mode 100644 index 000000000000..5c5b97292fac --- /dev/null +++ b/python/triton/debugger/debugger.py @@ -0,0 +1,170 @@ +import itertools +import random +from typing import Tuple + +import triton +import triton.language as tl +from .core import ExecutionContext +from .memory_map import MemoryMap +from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor, + debugger_constexpr) +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch +tl_method_backup = {} + + +def get_proxy_method(proxy, name): + method = getattr(proxy, name) + + def fun(*args, **kwarg): + return method(*args, **kwarg) + + return fun + + +def attach_triton(module, proxy): + method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"] + for name in method_list: + if hasattr(module, name): + attr = getattr(module, name) + tl_method_backup[name] = attr + if callable(attr): + setattr(module, name, get_proxy_method(proxy, name)) + else: + setattr(module, name, getattr(proxy, name)) + + +def detach_triton(module): + for name, method in tl_method_backup.items(): + setattr(module, name, method) + + +def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]: + # reverse the grid dimensions and generate the range for each dimension + reversed_grid = reversed(grid) + ranges_for_each_dimension = [range(dim) for dim in reversed_grid] + + # gen all combinations + index_combinations = list(itertools.product(*ranges_for_each_dimension)) + random.shuffle(index_combinations) + + for index_combination in index_combinations: + yield index_combination + + +class DebuggerFunction: + def __init__(self, func, grid=(1,)): + self.func = func + self.grid = grid + + def _is_constexpr(self, name): + return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr + + def _get_constexpr(self): + result = [] + for name, annotation in self.func.__annotations__.items(): + if annotation is triton.language.core.constexpr: + result.append(name) + return result + + def _assert_constexpr(self, **kwargs): + constexp = self._get_constexpr() + missing = [i for i in constexp if i not in kwargs.keys()] + assert len(missing) == 0, f"You must specify constexpr {missing}" + + def _get_grid(self, **kwargs): + if callable(self.grid): + return self.grid(kwargs) + else: + return self.grid + + def __call__(self, *args, **kwargs): + self._assert_constexpr(**kwargs) + + memory = MemoryMap() + + def convert_arg(v): + name, arg = v + if torch.is_tensor(arg): + ptr = memory.add_tensor(arg) + return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda")) + if self._is_constexpr(name): + return debugger_constexpr(arg) + return WrappedTensor(_primitive_to_tensor(arg)) + + new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args))) + new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]} + + grid = self._get_grid(**kwargs) + for program_id in program_ids_from_grid(grid): + proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid)) + attach_triton(tl, proxy) + self.func(*new_args, **new_kwargs) + detach_triton(tl) + + +class GridSelector: + """ + Entry point of the debugger + """ + + def __init__(self, func): + version = torch.__version__ + assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}" + self.func = func + + def __getitem__(self, grid): + return DebuggerFunction(self.func, grid) + + def __call__(self, *args, **kwargs): + return DebuggerFunction(self.func)(*args, **kwargs) + + +class AutotuneGridSelector: + def __init__(self, func, autotune_params): + self.func = func + self.autotune_params = autotune_params + + def __getitem__(self, grid): + return AutotuneRunner(self.func, self.autotune_params, grid) + + def __call__(self, *args, **kwargs): + return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs) + + +class AutotuneRunner: + def __init__(self, func, autotune_params, grid=None): + self.func = func + self.autotune_params = autotune_params + self.grid = grid + + def __call__(self, *args, **kwargs): + assert len(self.autotune_params["configs"]) >= 1 + + for config in self.autotune_params["configs"][1:]: + + def convert_arg(v): + if torch.is_tensor(v): + return torch.clone(v) + return v + + new_args = tuple(map(convert_arg, args)) + new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} + if self.grid: + self.func[self.grid](*new_args, **new_kwargs, **config.kwargs) + else: + self.func(*new_args, **new_kwargs, **config.kwargs) + + main_config = self.autotune_params["configs"][0] + if self.grid: + self.func[self.grid](*args, **kwargs, **main_config.kwargs) + else: + self.func(*args, **kwargs, **main_config.kwargs) + + +def triton_debug_autotune(**kwars): + def wrapper(func): + return AutotuneGridSelector(func, kwars) + + return wrapper diff --git a/python/triton/debugger/memory_map.py b/python/triton/debugger/memory_map.py new file mode 100644 index 000000000000..edf4c3f77922 --- /dev/null +++ b/python/triton/debugger/memory_map.py @@ -0,0 +1,100 @@ +import dataclasses + +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch + + +@dataclasses.dataclass +class RegisteredStorage: + storage: torch.Storage + dtype: torch.dtype + size: int + ptr: int + + @property + def end_ptr(self) -> int: + return self.ptr + self.size + + @property + def access_tensor(self) -> torch.Tensor: + return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device) + + def ensure_immutable(self): + assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size + + +class MemoryMap: + storages: [RegisteredStorage] + + def __init__(self): + self.storages = [] + + def _get_registered_storage(self, pointer: torch.Tensor): + max_pointer = torch.max(pointer).item() + min_pointer = torch.min(pointer).item() + + registered_storage = next( + filter( + lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages + ), + None, + ) + if registered_storage is None: + raise Exception("Storage not found or pointers spanning multiple tensors") + registered_storage.ensure_immutable() + return registered_storage + + def add_tensor(self, t: torch.Tensor): + storage = t.untyped_storage() + self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr())) + return t.data_ptr() + + def load( + self, + pointer: torch.Tensor, + mask: torch.Tensor = None, + other=0.0, + ): + assert pointer.is_cuda + assert 0 < pointer.dim() < 3 + assert pointer.dtype == torch.int64 + + if mask is None: + mask = torch.ones_like(pointer).bool() + assert mask.is_cuda + assert 0 < mask.dim() < 3 + assert mask.dtype == torch.bool + mask = mask.expand(pointer.size()) + + if torch.all(~mask): + # Todo: The type is wrong here, we can't determine the correct type + return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda") + + registered_storage = self._get_registered_storage(pointer[mask]) + access_tensor = registered_storage.access_tensor + + index_tensor = pointer - registered_storage.ptr + + block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda") + block[mask] = access_tensor[index_tensor[mask]] + return block + + def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): + assert 0 < pointer.dim() < 3 + assert pointer.dtype == torch.int64 + + if mask is None: + mask = torch.ones_like(pointer).bool() + assert 0 < mask.dim() < 3 + assert mask.dtype == torch.bool + mask = mask.expand(pointer.size()) + + if torch.all(~mask): + return + + registered_storage = self._get_registered_storage(pointer[mask]) + access_tensor = registered_storage.access_tensor + + index_tensor = pointer - registered_storage.ptr + access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype) diff --git a/python/triton/debugger/tl_lang.py b/python/triton/debugger/tl_lang.py new file mode 100644 index 000000000000..6364b77a3803 --- /dev/null +++ b/python/triton/debugger/tl_lang.py @@ -0,0 +1,621 @@ +import triton +from .core import ExecutionContext +from .memory_map import MemoryMap +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch + + +def _primitive_to_tensor(x): + """ + Converts various Python primitive data types to PyTorch tensor. + """ + tensor_args = {"device": "cuda"} + if isinstance(x, bool): + return torch.tensor([x], dtype=torch.bool, **tensor_args) + elif isinstance(x, int): + if -(2**31) <= x < 2**31: + return torch.tensor([x], dtype=torch.int32, **tensor_args) + elif -(2**63) <= x < 2**63: + return torch.tensor([x], dtype=torch.int64, **tensor_args) + else: + raise RuntimeError(f"Nonrepresentable integer {x}.") + elif isinstance(x, float): + return torch.tensor([x], dtype=torch.float32, **tensor_args) + elif torch.is_tensor(x): + return x + elif isinstance(x, WrappedTensor): + return x + elif isinstance(x, debugger_constexpr): + if x.value is None: + return None + return _primitive_to_tensor(x.value) + elif x is None: + return None + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +def _infer_tensor(func): + """ + A decorator function to harmonize function args: + - converts primitives to PyTorch tensors + - wraps PyTorch tensors with WrappedTensors + """ + def wrapper(*args): + new_args = tuple(map(lambda v: _primitive_to_tensor(v), args)) + new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args)) + + return func(*new_args) + + return wrapper + + +def _tensor_operation(func): + """ + A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function. + Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor). + """ + def wrapper(*args, **kwargs): + for arg in args: + assert not torch.is_tensor(arg), "unexpected tensor argument" + + def unwrap_tensor(v): + if isinstance(v, WrappedTensor): + return v.tensor + if isinstance(v, debugger_constexpr): + return v.value + return v + + new_args = tuple(map(unwrap_tensor, args)) + new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()} + + result = func(args[0], *new_args[1:], **new_kwargs) + return WrappedTensor(result) if torch.is_tensor(result) else result + + return wrapper + + +class debugger_constexpr: + def __init__(self, value): + if isinstance(value, debugger_constexpr): + self.value = value.value + else: + self.value = value + + def __str__(self) -> str: + return "debugger_constexpr(" + str(self.value) + ")" + + def __index__(self) -> int: + return self.value + + def __bool__(self): + return bool(self.value) + + def __ge__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value >= other + + def __gt__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value > other + + def __le__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value <= other + + def __lt__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value < other + + def __eq__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value == other + + def __or__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value | other + + def __ror__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value | other + + def __and__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value & other + + def __rand__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value & other + + def to(self, dtype, bitcast=False, _builder=None): + if dtype in [torch.int64]: + ret_ty = int + elif dtype == torch.bool: + ret_ty = bool + elif dtype in [torch.float64]: + ret_ty = float + else: + raise ValueError("dtype not supported in debugger") + return debugger_constexpr(ret_ty(self.value)) + + +class WrappedTensor: + def __init__(self, tensor): + self.tensor = tensor + + def __index__(self) -> int: + return self.tensor.item() + + def __str__(self) -> str: + return "wrapped_" + str(self.tensor) + + def __bool__(self) -> bool: + return torch.all(self.tensor == True).item() # noqa: E712 + + @property + def dtype(self): + return self.tensor.dtype + + @_infer_tensor + @_tensor_operation + def __add__(self, other): + return torch.add(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __radd__(self, other): + return self.__add__(other) + + @_infer_tensor + @_tensor_operation + def __sub__(self, other): + return torch.sub(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rsub__(self, other): + return torch.sub(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __mul__(self, other): + return torch.mul(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rmul__(self, other): + return self.__mul__(other) + + @_infer_tensor + @_tensor_operation + def __truediv__(self, other): + return torch.div(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rtruediv__(self, other): + return torch.div(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __floordiv__(self, other): + return torch.floor_divide(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rfloordiv__(self, other): + return torch.floor_divide(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __mod__(self, other): + return torch.remainder(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rmod__(self, other): + return torch.remainder(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __neg__(self): + return -self.tensor + + @_infer_tensor + @_tensor_operation + def __invert__(self): + return ~self.tensor + + @_infer_tensor + @_tensor_operation + def __and__(self, other): + return torch.bitwise_and(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __or__(self, other): + return torch.bitwise_or(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __xor__(self, other): + return torch.bitwise_xor(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __lshift__(self, other): + return torch.bitwise_left_shift(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rshift__(self, other): + return torch.bitwise_right_shift(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __gt__(self, other): + return self.tensor > other + + @_infer_tensor + @_tensor_operation + def __rgt__(self, other): + return other > self.tensor + + @_infer_tensor + @_tensor_operation + def __ge__(self, other): + return self.tensor >= other + + @_infer_tensor + @_tensor_operation + def __rge__(self, other): + return other >= self.tensor + + @_infer_tensor + @_tensor_operation + def __lt__(self, other): + return self.tensor < other + + @_infer_tensor + @_tensor_operation + def __rlt__(self, other): + return other < self.tensor + + @_infer_tensor + @_tensor_operation + def __le__(self, other): + return self.tensor <= other + + @_infer_tensor + @_tensor_operation + def __rle__(self, other): + return other <= self.tensor + + @_infer_tensor + @_tensor_operation + def __eq__(self, other): + return torch.equal(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __ne__(self, other): + return not torch.equal(self.tensor, other) + + @_tensor_operation + def __getitem__(self, slices): + return self.tensor.__getitem__(slices) + # if isinstance(slices, slice): + # slices = [slices] + # src_shape = self.shape + # dst_shape = [] + # curr = 0 + # for sl in slices: + # if isinstance(sl, constexpr) and sl.value is None: + # dst_shape.append(1) + # elif sl == slice(None, None, None): + # dst_shape.append(src_shape[curr].value) + # curr += 1 + # ret = torch.reshape(self.tensor, dst_shape, ) + # return ret + + @_tensor_operation + def to(self, dtype, bitcast=False): + return self.tensor.to(dtype) + # if isinstance(bitcast, constexpr): + # bitcast = bitcast.value + # if bitcast: + # return semantic.bitcast(self, dtype, ) + # return semantic.cast(self, dtype, ) + + +def _constexpr_to_value(v): + if isinstance(v, debugger_constexpr): + return v.value + return v + + +class TritonLangProxy: + _memory_map: MemoryMap + _context: ExecutionContext + + def __init__(self, memory_map: MemoryMap, context: ExecutionContext): + self._memory_map = memory_map + self._context = context + + # Types + # Removed void, int1, float8, uint16, uint32, uint64, pi32_t + + # constexpr = debugger_constexpr + + # Program functions + + @_tensor_operation + def load( + self, + pointer: torch.Tensor, + mask: torch.Tensor = None, + other=0.0, + cache_modifier="", + eviction_policy="", + volatile=False, + ): + return self._memory_map.load(pointer, mask, other) + + @_tensor_operation + def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): + return self._memory_map.store(pointer, value, mask) + + @_tensor_operation + def program_id(self, axis): + assert axis < len(self._context.program_id) + return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda") + + @_tensor_operation + def num_programs(self, axis): + assert axis < len(self._context.program_size) + return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda") + + @_tensor_operation + def arange(self, start, end): + return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda") + + @_tensor_operation + def zeros(self, shape, dtype): + for i, d in enumerate(shape): + if not isinstance(d, debugger_constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + shape = [x.value for x in shape] + if isinstance(dtype, triton.language.core.dtype): + if dtype.is_fp32(): + dtype = torch.float32 + elif dtype.is_fp16(): + dtype = torch.float16 + elif dtype.is_bf16(): + dtype = torch.bfloat16 + elif dtype.is_int32(): + dtype = torch.int32 + elif dtype.is_int16(): + dtype = torch.int16 + elif dtype.is_int8(): + dtype = torch.int8 + else: + raise TypeError(f"Unsupported dtype {dtype}") + return torch.zeros(size=shape, dtype=dtype, device="cuda") + + @_tensor_operation + def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16): + raise NotImplementedError() + + @_tensor_operation + def broadcast(self, input, other): + raise NotImplementedError() + + @_tensor_operation + def broadcast_to(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def cat(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def reshape(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True): + assert input.dtype == other.dtype + if trans_a: + input = input.T + if trans_b: + other = other.T + return torch.matmul(input=input, other=other) + + @_tensor_operation + def atomic_cas(self, pointer, cmp, val): + stored = self._memory_map.load(pointer, None, 0.0) + if not isinstance(cmp, torch.Tensor): + cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda") + if not isinstance(val, torch.Tensor): + val = torch.tensor([val], dtype=stored.dtype, device="cuda") + if stored == cmp: + self._memory_map.store(pointer, val, None) + return stored + + @_tensor_operation + def atomic_xchg(self, pointer, val, mask=None): + if isinstance(val, int): + val = torch.tensor([val], dtype=torch.int32, device="cuda") + stored = self._memory_map.load(pointer, mask, 0.0) + self._memory_map.store(pointer, val, mask) + return stored + + @_tensor_operation + def atomic_add(self, pointer, val, mask=None): + # arbitrary other value as it will masked during storing + stored = self._memory_map.load(pointer, mask, 0.0) + result = stored + val + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_max(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0.0) + result = torch.maximum(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_min(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0.0) + result = torch.minimum(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_and(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_and(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_or(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_or(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_xor(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_xor(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def where(self, condition, x, y): + condition = _primitive_to_tensor(condition) + x = _primitive_to_tensor(x) + y = _primitive_to_tensor(y) + return torch.where(condition, x, y) + + @_tensor_operation + def umulhi(self, x, y): + raise NotImplementedError() + + @_tensor_operation + def fdiv(self, x, y, ieee_rounding=False): + raise NotImplementedError() + + @_tensor_operation + def exp(self, x): + return torch.exp(x) + + @_tensor_operation + def log(self, x): + return torch.log(x) + + @_tensor_operation + def cos(self, x): + return torch.cos(x) + + @_tensor_operation + def sin(self, x): + return torch.sin(x) + + @_tensor_operation + def sqrt(self, x): + return torch.sqrt(x) + + @_tensor_operation + def globaltimer(self): + raise NotImplementedError() + + @_tensor_operation + def clock(self): + raise NotImplementedError() + + @_tensor_operation + def debug_barrier(self): + raise NotImplementedError() + + @_tensor_operation + def multiple_of(self, input, values): + return input + + @_tensor_operation + def max_contiguous(self, input, values): + return input + + @_tensor_operation + def abs(self, x): + return torch.abs(x) + + @_tensor_operation + def cdiv(self, x, div): + return (x + div - 1) // div + + @_tensor_operation + def minimum(self, x, y): + if isinstance(x, int): + x = torch.tensor(x, device="cuda") + if isinstance(y, int): + y = torch.tensor(y, device="cuda") + return torch.minimum(x, y) + + @_tensor_operation + def maximum(self, x, y): + return torch.maximum(x, y) + + @_tensor_operation + def sigmoid(self, x): + raise NotImplementedError() + + @_tensor_operation + def softmax(self, x, ieee_rounding=False): + raise NotImplementedError() + + @_tensor_operation + def ravel(self, x): + raise NotImplementedError() + + @_tensor_operation + def swizzle2d(self, i, j, size_i, size_j, size_g): + raise NotImplementedError() + + @_tensor_operation + def zeros_like(self, input): + raise NotImplementedError() + + @_tensor_operation + def max(self, input, axis=None): + if axis is None: + return torch.max(input) + return torch.max(input, dim=axis).values + + @_tensor_operation + def argmax(self, input, axis): + raise NotImplementedError() + + @_tensor_operation + def min(self, input, axis=None): + if axis is None: + return torch.min(input) + return torch.min(input, dim=axis).values + + @_tensor_operation + def argmin(self, input, axis): + raise NotImplementedError() + + @_tensor_operation + def sum(self, input, axis=None): + if axis is None: + return torch.sum(input) + return torch.sum(input, dim=axis) + + @_tensor_operation + def xor_sum(self, input, axis): + raise NotImplementedError() diff --git a/python/triton/debugger/torch_wrapper.py b/python/triton/debugger/torch_wrapper.py new file mode 100644 index 000000000000..44aa17eb1355 --- /dev/null +++ b/python/triton/debugger/torch_wrapper.py @@ -0,0 +1,18 @@ +try: + import torch as _torch +except ImportError: + _torch = None + + +class TorchWrapper: + """ + Helps in making torch an optional dependency + """ + + def __getattr__(self, name): + if _torch is None: + raise ImportError("Triton requires PyTorch to be installed") + return getattr(_torch, name) + + +torch = TorchWrapper() diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index a6e7de866c7d..bd3108141695 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -439,6 +439,7 @@ def jit( do_not_specialize: Optional[Iterable[int]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, + interpret: Optional[bool] = None, ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: """ Decorator for JIT-compiling a function using the Triton compiler. @@ -460,14 +461,17 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - return JITFunction( - fn, - version=version, - do_not_specialize=do_not_specialize, - debug=debug, - noinline=noinline, - ) - + if interpret: + from ..debugger.debugger import GridSelector + return GridSelector(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + ) if fn is not None: return decorator(fn)