[FRONTEND] Added interpreter mode (triton-lang#1573)
Simple mechanism to run Triton kernels on PyTorch for debugging purposes
(upstreamed from Kernl).

Todo:
- random grid iteration
- support for atomic ops
- more unit tests
- cover new APIs?
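
A minimal usage sketch of the interpreter mode, condensed from the new unit test added below in python/test/unit/debugger/test_debugger.py; the interpret=True flag on triton.jit is the switch introduced by this commit:

import torch
import triton
import triton.language as tl


# Passing interpret=True runs the kernel on PyTorch instead of compiling it.
@triton.jit(interpret=True)
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.rand(128, device="cuda")
y = torch.rand(128, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(128, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, 128, BLOCK_SIZE=32)
assert torch.allclose(out, x + y)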
pommedeterresautee authored May 8, 2023
1 parent 50daf6c commit 2f38274
Showing 11 changed files with 1,065 additions and 70 deletions.
1 change: 1 addition & 0 deletions python/setup.py
@@ -249,6 +249,7 @@ def build_extension(self, ext):
"triton/_C",
"triton/common",
"triton/compiler",
"triton/debugger",
"triton/language",
"triton/language/extra",
"triton/ops",
69 changes: 69 additions & 0 deletions python/test/unit/debugger/test_debugger.py
@@ -0,0 +1,69 @@
import random

import torch

import triton
import triton.language as tl
from triton.debugger.debugger import program_ids_from_grid


def test_addition():

@triton.jit(interpret=True)
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)

a = torch.rand((128,), device="cuda")
b = torch.rand((128,), device="cuda")
expected = a + b
output = torch.empty((128,), device="cuda")

def grid(meta):
return (triton.cdiv(128, meta["BLOCK_SIZE"]),)

add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)

assert torch.allclose(expected, output, atol=1e-2, rtol=0)


def test_program_ids_from_grid():
random.seed(123)
grid = (3, 4)
expected_combinations = 3 * 4
unique_combinations = set(program_ids_from_grid(grid))
assert len(unique_combinations) == expected_combinations

first_run = list(program_ids_from_grid(grid))
second_run = list(program_ids_from_grid(grid))
assert first_run != second_run


def test_atomic():
@triton.jit(interpret=True)
def atomic(
x_ptr,
):
pid = tl.program_id(axis=0)
tl.atomic_add(x_ptr + pid, 1)
t = tl.atomic_xchg(x_ptr + pid, 3)
t += 1 # 2
tl.atomic_cas(x_ptr + pid, 3, t) # match
tl.atomic_cas(x_ptr + pid, 40, 9) # no match
nb_dim = 16
a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")

atomic[(nb_dim, )](a)
assert torch.allclose(a, torch.full_like(a, 2))
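
For reference, a hedged sketch of what program_ids_from_grid is expected to do, reconstructed purely from the assertions in test_program_ids_from_grid above; the real implementation lives in triton/debugger/debugger.py and may differ:

import itertools
import random
from typing import Iterator, Tuple


def program_ids_from_grid(grid: Tuple[int, ...]) -> Iterator[Tuple[int, ...]]:
    # Enumerate every program id in the launch grid, then yield them in a
    # shuffled order so that bugs depending on a fixed iteration order surface.
    ids = list(itertools.product(*(range(size) for size in grid)))
    random.shuffle(ids)
    yield from ids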
124 changes: 62 additions & 62 deletions python/test/unit/language/test_core.py
@@ -1605,68 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)


layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]


@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
@pytest.mark.parametrize("src_layout", layouts)
def test_reduce_2d(M, N, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')

z = np.zeros((1,)).astype('int32')

x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)

pgm = kernel[(1, 1, 1)](x_tri, z_tri)

z_ref = np.sum(x)

np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
# layouts = [
# BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
# BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
# BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
# ]


# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
# @pytest.mark.parametrize("src_layout", layouts)
# def test_reduce_2d(M, N, src_layout, device='cuda'):
# ir = f"""
# #src = {src_layout}
# module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
# tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
# %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
# %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
# %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
# %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
# %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
# %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
# %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
# %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
# %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
# %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
# %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
# %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
# %11 = "tt.reduce"(%10) ({{
# ^bb0(%arg2: i32, %arg3: i32):
# %13 = arith.addi %arg2, %arg3 : i32
# tt.reduce.return %13 : i32
# }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
# %12 = "tt.reduce"(%11) ({{
# ^bb0(%arg2: i32, %arg3: i32):
# %13 = arith.addi %arg2, %arg3 : i32
# tt.reduce.return %13 : i32
# }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
# tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
# tt.return
# }}
# }}
# """
# import tempfile
# with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
# f.write(ir)
# f.flush()
# kernel = triton.compile(f.name)
#
# rs = RandomState(17)
# x = rs.randint(0, 4, (M, N)).astype('int32')
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
#
# z = np.zeros((1,)).astype('int32')
#
# x_tri = torch.tensor(x, device=device)
# z_tri = torch.tensor(z, device=device)
#
# pgm = kernel[(1, 1, 1)](x_tri, z_tri)
#
# z_ref = np.sum(x)
#
# np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


def test_generic_reduction(device='cuda'):
3 changes: 3 additions & 0 deletions python/triton/__init__.py
@@ -18,6 +18,8 @@
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from .debugger.debugger import program_ids_from_grid

from . import language
from . import testing

@@ -41,6 +43,7 @@
"runtime",
"TensorWrapper",
"testing",
"program_ids_from_grid",
]


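With the __init__.py change above, program_ids_from_grid becomes part of the public triton namespace; a short usage sketch:

from triton import program_ids_from_grid

# Visit every program id of a 2x3 grid in a randomized order.
for pid in program_ids_from_grid((2, 3)):
    print(pid)  # e.g. (1, 0), (0, 2), ...; the order varies between runs
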
Empty file.
9 changes: 9 additions & 0 deletions python/triton/debugger/core.py
@@ -0,0 +1,9 @@
from typing import Tuple

import dataclasses


@dataclasses.dataclass
class ExecutionContext:
program_id: Tuple[int]
program_size: Tuple[int]
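
A hedged sketch of how the interpreter could consume ExecutionContext when emulating tl.program_id and tl.num_programs; the interpreted_* helpers below are illustrative names, not part of this commit:

from triton.debugger.core import ExecutionContext


def interpreted_program_id(context: ExecutionContext, axis: int) -> int:
    # Under interpretation there is no hardware register to read: the current
    # grid coordinates are simply carried in the execution context.
    return context.program_id[axis]


def interpreted_num_programs(context: ExecutionContext, axis: int) -> int:
    return context.program_size[axis]


ctx = ExecutionContext(program_id=(1, 0), program_size=(3, 4))
assert interpreted_program_id(ctx, axis=0) == 1
assert interpreted_num_programs(ctx, axis=1) == 4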