feat: PyTorch debugger #1573

Merged 31 commits on May 8, 2023
Commits

All commits are by pommedeterresautee.

f5d3b80  feat: setup interpreter (Apr 24, 2023)
9c94995  feat: random grid iteration (Apr 25, 2023)
57c1955  Merge branch 'main' into feat/interpreter (Apr 25, 2023)
7f4803e  feat: add atomic ops (Apr 27, 2023)
9bf38ae  merge main (May 2, 2023)
42cc175  fix debugger import issue (May 2, 2023)
0bf7104  fix: limit module exports (May 2, 2023)
11634b9  fix: export module in setup.py (May 2, 2023)
c76f42f  fix: tuple -> Tuple (old Python support) (May 2, 2023)
41f9519  fix: remove torch dependency (May 2, 2023)
86aced7  fix: harmonize error message (May 2, 2023)
90c1331  Merge branch 'main' into feat/interpreter (May 3, 2023)
41488a8  Merge branch 'main' into feat/interpreter (May 3, 2023)
d59ba19  fix: simplify mechanism to make torch optional (May 3, 2023)
aa1f527  Merge branch 'main' into feat/interpreter (May 4, 2023)
dce5db2  fix: check torch version (May 4, 2023)
86d3e6c  Merge branch 'main' into feat/interpreter (May 8, 2023)
047b251  feat: setup interpreter (Apr 24, 2023)
bb3e4c0  feat: random grid iteration (Apr 25, 2023)
1c2f5ad  feat: add atomic ops (Apr 27, 2023)
6e89e6b  fix debugger import issue (May 2, 2023)
2839e34  fix: limit module exports (May 2, 2023)
d6c93a8  fix: export module in setup.py (May 2, 2023)
ee877ba  fix: tuple -> Tuple (old Python support) (May 2, 2023)
3f51fb6  fix: remove torch dependency (May 2, 2023)
e492b46  fix: harmonize error message (May 2, 2023)
a501314  fix: simplify mechanism to make torch optional (May 3, 2023)
3ad8c43  fix: check torch version (May 4, 2023)
e33280e  fix: fix assert (May 8, 2023)
2772e39  Merge remote-tracking branch 'origin/feat/interpreter' into feat/inte… (May 8, 2023)
7154fcb  fix: disable test_reduce_2d (May 8, 2023)
1 change: 1 addition & 0 deletions python/setup.py
@@ -249,6 +249,7 @@ def build_extension(self, ext):
"triton/_C",
"triton/common",
"triton/compiler",
"triton/debugger",
"triton/language",
"triton/language/extra",
"triton/ops",
69 changes: 69 additions & 0 deletions python/test/unit/debugger/test_debugger.py
@@ -0,0 +1,69 @@
import random

import torch

import triton
import triton.language as tl
from triton.debugger.debugger import program_ids_from_grid


def test_addition():

    @triton.jit(interpret=True)
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)

    a = torch.rand((128,), device="cuda")
    b = torch.rand((128,), device="cuda")
    expected = a + b
    output = torch.empty((128,), device="cuda")

    def grid(meta):
        return (triton.cdiv(128, meta["BLOCK_SIZE"]),)

    add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)

    assert torch.allclose(expected, output, atol=1e-2, rtol=0)


def test_program_ids_from_grid():
    random.seed(123)
    grid = (3, 4)
    expected_combinations = 3 * 4
    unique_combinations = set(program_ids_from_grid(grid))
    assert len(unique_combinations) == expected_combinations

    first_run = list(program_ids_from_grid(grid))
    second_run = list(program_ids_from_grid(grid))
    assert first_run != second_run


def test_atomic():
    @triton.jit(interpret=True)
    def atomic(
        x_ptr,
    ):
        pid = tl.program_id(axis=0)
        tl.atomic_add(x_ptr + pid, 1)
        t = tl.atomic_xchg(x_ptr + pid, 3)
        t += 1  # 2
        tl.atomic_cas(x_ptr + pid, 3, t)  # match
        tl.atomic_cas(x_ptr + pid, 40, 9)  # no match

    nb_dim = 16
    a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")

    atomic[(nb_dim, )](a)
    assert torch.allclose(a, torch.full_like(a, 2))
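
The tests above exercise the new interpret=True mode end to end. As a usage illustration only (copy_kernel and its arguments are made up for this note, and the claim that plain Python tooling works inside the kernel body is an assumption based on the PR's interpreter design rather than something shown in this diff), a debugging session could look like:

import torch
import triton
import triton.language as tl


@triton.jit(interpret=True)  # run the kernel through the interpreter instead of compiling it
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    # Under interpretation this body runs as ordinary Python, so a print() or
    # breakpoint() inserted here is expected to work (assumption, see note above).
    tl.store(dst_ptr + offsets, x, mask=mask)


src = torch.arange(64, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
copy_kernel[(2,)](src, dst, 64, BLOCK_SIZE=32)
assert torch.equal(src, dst)

The pattern mirrors test_addition above, with the load/store path reduced to a straight copy.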
124 changes: 62 additions & 62 deletions python/test/unit/language/test_core.py
@@ -1605,68 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)


layouts = [
    BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
    BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
    BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]


@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
@pytest.mark.parametrize("src_layout", layouts)
def test_reduce_2d(M, N, src_layout, device='cuda'):
    ir = f"""
    #src = {src_layout}
    module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
    tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
    %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
    %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
    %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
    %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
    %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
    %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
    %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
    %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
    %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
    %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
    %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
    %11 = "tt.reduce"(%10) ({{
    ^bb0(%arg2: i32, %arg3: i32):
        %13 = arith.addi %arg2, %arg3 : i32
        tt.reduce.return %13 : i32
    }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
    %12 = "tt.reduce"(%11) ({{
    ^bb0(%arg2: i32, %arg3: i32):
        %13 = arith.addi %arg2, %arg3 : i32
        tt.reduce.return %13 : i32
    }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
    tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
    tt.return
    }}
    }}
    """
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        kernel = triton.compile(f.name)

    rs = RandomState(17)
    x = rs.randint(0, 4, (M, N)).astype('int32')
    x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')

    z = np.zeros((1,)).astype('int32')

    x_tri = torch.tensor(x, device=device)
    z_tri = torch.tensor(z, device=device)

    pgm = kernel[(1, 1, 1)](x_tri, z_tri)

    z_ref = np.sum(x)

    np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
# layouts = [
#     BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
#     BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
#     BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
#     BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
# ]


# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
# @pytest.mark.parametrize("src_layout", layouts)
# def test_reduce_2d(M, N, src_layout, device='cuda'):
#     ir = f"""
#     #src = {src_layout}
#     module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
#     tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
#     %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
#     %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
#     %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
#     %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
#     %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
#     %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
#     %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
#     %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
#     %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
#     %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
#     %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
#     %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
#     %11 = "tt.reduce"(%10) ({{
#     ^bb0(%arg2: i32, %arg3: i32):
#         %13 = arith.addi %arg2, %arg3 : i32
#         tt.reduce.return %13 : i32
#     }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
#     %12 = "tt.reduce"(%11) ({{
#     ^bb0(%arg2: i32, %arg3: i32):
#         %13 = arith.addi %arg2, %arg3 : i32
#         tt.reduce.return %13 : i32
#     }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
#     tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
#     tt.return
#     }}
#     }}
#     """
#     import tempfile
#     with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
#         f.write(ir)
#         f.flush()
#         kernel = triton.compile(f.name)
#
#     rs = RandomState(17)
#     x = rs.randint(0, 4, (M, N)).astype('int32')
#     x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
#
#     z = np.zeros((1,)).astype('int32')
#
#     x_tri = torch.tensor(x, device=device)
#     z_tri = torch.tensor(z, device=device)
#
#     pgm = kernel[(1, 1, 1)](x_tri, z_tri)
#
#     z_ref = np.sum(x)
#
#     np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


def test_generic_reduction(device='cuda'):
3 changes: 3 additions & 0 deletions python/triton/__init__.py
@@ -18,6 +18,8 @@
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from .debugger.debugger import program_ids_from_grid

from . import language
from . import testing

@@ -41,6 +43,7 @@
"runtime",
"TensorWrapper",
"testing",
"program_ids_from_grid",
]
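
program_ids_from_grid is re-exported at the package level above; its implementation lives in python/triton/debugger/debugger.py, which is not reproduced in this excerpt. Going only by the behaviour pinned down in test_program_ids_from_grid (each grid coordinate yielded exactly once, in a randomized order), a minimal sketch could look like the following; treat it as an illustration, not the code that was merged:

import itertools
import random
from typing import Iterator, Tuple


def program_ids_from_grid(grid: Tuple[int, ...]) -> Iterator[Tuple[int, ...]]:
    # Enumerate every (pid_0, pid_1, ...) coordinate in the launch grid...
    ids = list(itertools.product(*(range(size) for size in grid)))
    # ...then visit them in a shuffled order, so bugs that silently rely on a
    # fixed program execution order are more likely to surface under the interpreter.
    random.shuffle(ids)
    yield from ids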


Empty file.
9 changes: 9 additions & 0 deletions python/triton/debugger/core.py
@@ -0,0 +1,9 @@
from typing import Tuple

import dataclasses


@dataclasses.dataclass
class ExecutionContext:
    program_id: Tuple[int]
    program_size: Tuple[int]
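
ExecutionContext carries the grid coordinates (program_id) and the launch-grid shape (program_size) for one simulated program instance. How the interpreter consumes it is not shown in this excerpt; as a sketch with hypothetical helper names, built on the ExecutionContext dataclass defined above, the emulated builtins could simply read these fields:

def emulated_program_id(ctx: ExecutionContext, axis: int) -> int:
    # tl.program_id(axis) under interpretation reduces to looking up the
    # coordinate assigned to this simulated program instance.
    return ctx.program_id[axis]


def emulated_num_programs(ctx: ExecutionContext, axis: int) -> int:
    # tl.num_programs(axis) maps to the launch grid extent along that axis.
    return ctx.program_size[axis]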