Skip to content

Commit

Permalink
[FRONTEND] Mangle signed and unsigned integer types differently (trit…
Browse files Browse the repository at this point in the history
…on-lang#1340)

This is cherry-picked from triton-lang#1305

If you call a `JITFunction` twice in the same kernel, first with `int32`
then with `uint32`, the second call will treat the unsigned value as
signed. This passes through MLIR without error because MLIR uses the
same types for both, but different operation calls will be generated so
you may silently get the wrong result.
  • Loading branch information
peterbell10 authored Mar 15, 2023
1 parent fdec9c1 commit a23c4ef
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
37 changes: 37 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,43 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


def test_unsigned_name_mangling(device='cuda'):
# Test that uint32 and int32 are mangled differently by the compiler
SIZE = 128
# define the kernel / launch-grid

@triton.jit
def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(X + off)
y = tl.load(Y + off)
out1 = tl.abs(x) # uint32 -> nop
out2 = tl.abs(-y) # int32 -> should have an effect
tl.store(O1 + off, out1)
tl.store(O2 + off, out2)

dtype_x = 'uint32'
dtype_y = 'int32'
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
# reference result
expect = (np.abs(x), np.abs(-y))
# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
y_tri = to_triton(y, device=device, dst_type=dtype_y)
actual = tuple(
to_triton(np.empty_like(e), device=device)
for e in expect
)
kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)

# Bitwise op, so expect exact equality
assert (expect[0] == to_numpy(actual[0])).all()
assert (expect[1] == to_numpy(actual[1])).all()


# ---------------
# test bitwise ops
# ---------------
Expand Down
4 changes: 3 additions & 1 deletion python/triton/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
SIGNED = triton.language.dtype.SIGNEDNESS.SIGNED
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
return prefix + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
Expand Down

0 comments on commit a23c4ef

Please sign in to comment.