Skip to content

Commit

Permalink
[FRONTEND] Fix @triton.jit(debug=True) (#5037)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 authored Nov 1, 2024
1 parent d0db12b commit 3ca2f49
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
26 changes: 16 additions & 10 deletions python/test/unit/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,33 @@
import triton.language as tl
import triton

@pytest.mark.parametrize('cond, opt_flag, env_var', [
(cond, opt_flag, env_var) for cond in [True, False] \
for opt_flag in [True, False] \
for env_var in [True, False]\
])

@pytest.mark.parametrize('cond', [True, False])
@pytest.mark.parametrize('opt_flag', [True, False, None])
@pytest.mark.parametrize('env_var', [True, False])
@pytest.mark.parametrize('jit_flag', [True, False])
@pytest.mark.forked
def test_device_assert(cond, opt_flag, env_var, device):
def test_device_assert(cond, opt_flag, env_var, jit_flag, device):
os.environ['TRITON_DEBUG'] = str(int(env_var))
torch.zeros([1], dtype=torch.int32, device=device)

@triton.jit
@triton.jit(debug=jit_flag)
def _kernel(COND: tl.constexpr):
tl.device_assert(COND, 'test')

if not cond and (opt_flag or env_var):
is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag)

kwargs = {}
if opt_flag is not None:
kwargs["debug"] = opt_flag

if not cond and is_debug:
with pytest.raises(RuntimeError):
_kernel[(1, )](cond, debug=opt_flag)
_kernel[(1, )](cond, **kwargs)
getattr(torch, device).synchronize()
return

_kernel[(1, )](cond, debug=opt_flag)
_kernel[(1, )](cond, **kwargs)
getattr(torch, device).synchronize()


Expand Down
1 change: 0 additions & 1 deletion python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def builtin(fn: T) -> T:
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
print(kwargs)
raise ValueError("Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def create_binder(self, backend):
]

def run(self, *args, grid, warmup, **kwargs):
kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1"
kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"

# parse options
from ..compiler import make_backend
Expand Down Expand Up @@ -698,6 +698,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel = None
self.debug = debug
self.noinline = noinline

# TODO(jlebar): Remove uses of these fields outside this file, then
Expand Down

0 comments on commit 3ca2f49

Please sign in to comment.