[Mosaic GPU] Support narrower swizzles in copy and TMA tests
PiperOrigin-RevId: 649045134
apaszke authored and jax authors committed Jul 3, 2024
1 parent c00ac4f commit ade76f0
Showing 2 changed files with 36 additions and 35 deletions.
10 changes: 9 additions & 1 deletion jax/experimental/mosaic/gpu/__init__.py
@@ -329,6 +329,7 @@ def async_copy(
src_ref_ty = ir.MemRefType(src_ref.type)
dst_ref_ty = ir.MemRefType(dst_ref.type)
element_type = src_ref_ty.element_type
element_bytewidth = mgpu.bytewidth(element_type)
if element_type != dst_ref_ty.element_type:
raise ValueError(
f"Expected same element type, got {element_type} and"
@@ -397,10 +398,17 @@ def async_copy(
rank = len(slice_shape)
if rank > 5: # TODO: apaszke - Implement stride compression
raise ValueError("Async copies only support striding up to 5 dimensions")
if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth:
raise ValueError(
f"Async copies with {swizzle=} require last dimension of the slice to"
f" be exactly {swizzle} bytes"
f" ({swizzle // element_bytewidth} elements), but got"
f" {slice_shape[-1]}"
)
smem_ptr = utils.memref_ptr(smem_ref, memory_space=3)
if gmem_ref is src_ref:
assert barrier is not None # for pytype
slice_bytes = c(np.prod(slice_shape) * mgpu.bytewidth(element_type), i32)
slice_bytes = c(np.prod(slice_shape) * element_bytewidth, i32)
barrier_ptr = barrier.get_ptr()
with uniform_ctx():
if arrive:
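Note: the new check above encodes the TMA constraint that a swizzled copy must move rows of exactly `swizzle` bytes. A minimal standalone sketch of the same rule (hypothetical helper, not part of this commit):

    def check_swizzle(slice_shape, element_bytewidth, swizzle):
        # A swizzled TMA row must span exactly `swizzle` bytes.
        if swizzle is not None and slice_shape[-1] * element_bytewidth != swizzle:
            raise ValueError(
                f"{swizzle=} requires a last dimension of"
                f" {swizzle // element_bytewidth} elements, got {slice_shape[-1]}"
            )

    check_swizzle((5, 16), element_bytewidth=4, swizzle=64)  # OK: 16 * 4 == 64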
61 changes: 27 additions & 34 deletions tests/mosaic/gpu_test.py
@@ -84,32 +84,21 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
def body(*idx):
dst_idx = idx
if swizzle is not None:
if swizzle != 128:
raise NotImplementedError("Only swizzle 128B implemented")
# TODO(apaszke): This can probably be cleaned up.
# But it works and it's test-only, so it doesn't matter much.
# After all, swizzle should just be an xor of row and linear idx,
# adjusted for the bytewidth.
assert swizzle.bit_count() == 1
bytes_per_element = bytewidth(src_ty.element_type)
elems_per_tile = 1024 // bytes_per_element
elems_per_row = elems_per_tile // 8
elems_per_group = 16 // bytes_per_element
linear_idx = c(0, index)
for stride, i in zip(dyn_strides, idx):
linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
tile_offset = arith.remui(linear_idx, c(elems_per_tile, index))
linear_tile_start = arith.subi(linear_idx, tile_offset)
row = arith.divui(tile_offset, c(elems_per_row, index))
row_offset = arith.remui(tile_offset, c(elems_per_row, index))
src_group = arith.divui(row_offset, c(elems_per_group, index))
group_offset = arith.remui(row_offset, c(elems_per_group, index))
dst_group = arith.xori(src_group, row)
dst_linear_idx = mlir_sum([
linear_tile_start,
arith.muli(row, c(elems_per_row, index)),
arith.muli(dst_group, c(elems_per_group, index)),
group_offset,
])
# Swizzle pattern repeats every 128 bytes.
swizzle_src = arith.remui(
arith.divui(linear_idx, c(128 // bytes_per_element, index)),
c(swizzle // 16, index),
)
# Swizzle happens in groups of 16 bytes.
swizzle_shift = 4 - (bytes_per_element.bit_length() - 1)
dst_linear_idx = arith.xori(
linear_idx, arith.shli(swizzle_src, c(swizzle_shift, index))
)
dst_idx = [
arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index))
for stride, bound in zip(dyn_strides, shape)
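The rewritten helper replaces the 128-byte-only code path with a closed-form XOR: within each 128-byte row, the 16-byte group index is XOR-ed with the row index modulo `swizzle // 16`. A pure-Python sketch of the same index math (names are illustrative, not from the commit):

    def swizzled_index(linear_idx, bytes_per_element, swizzle):
        # The swizzle pattern repeats every 128 bytes.
        elems_per_row = 128 // bytes_per_element
        # Swizzling permutes 16-byte groups, so shift by log2(elems per group).
        group_shift = 4 - (bytes_per_element.bit_length() - 1)
        swizzle_src = (linear_idx // elems_per_row) % (swizzle // 16)
        return linear_idx ^ (swizzle_src << group_shift)

    # float16 (2 bytes), swizzle=128: element 64 starts row 1,
    # so it lands at 64 ^ (1 << 3) == 72.
    assert swizzled_index(64, 2, 128) == 72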
@@ -714,16 +703,17 @@ def kernel(ctx, dst, tmp):
)()
np.testing.assert_array_equal(y, np.full_like(y, 3, dtype=np.int32))


class TMATest(TestCase):

@parameterized.product(
swizzle=(None, 128),
shape=((64, 64), (5, 64), (2, 3, 5, 64)),
swizzle=(None, 32, 64, 128),
shape=((64, None), (5, None), (2, 3, 5, None)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_load_basic(self, swizzle, shape, dtype):
if dtype == jnp.float32:
shape = (*shape[:-1], shape[-1] // 2)
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
shape = (*shape[:-1], minor_size)
i1 = ir.IntegerType.get_signless(1)
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
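The test shapes are now parameterized with a None minor dimension that each test fills in from the swizzle, since a swizzled row must be exactly `swizzle` bytes wide. A quick illustration of the values this produces (a sketch of the `minor_size` expression above):

    def minor_size(swizzle, itemsize):
        # Without swizzling, keep the previous hard-coded 64-element minor dim.
        return 64 if swizzle is None else swizzle // itemsize

    assert minor_size(None, 2) == 64  # float16, unswizzled
    assert minor_size(32, 4) == 8     # float32, 32-byte swizzle
    assert minor_size(128, 2) == 64   # float16, 128-byte swizzle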
@@ -740,9 +730,11 @@ def kernel(ctx, src, dst, tmp):
dtype=(jnp.float16, jnp.float32),
)
def test_tma_load_tiled(self, swizzle, shape, dtype):
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
# swizzle 32 and 64.
i1 = ir.IntegerType.get_signless(1)
index = ir.IndexType.get()
tiling = (32, 128 // jnp.dtype(dtype).itemsize)
tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize)
tiled_shape = tile_shape(shape, tiling)[:len(shape)]
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
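test_tma_load_tiled likewise derives the minor tile dimension from the swizzle, falling back to 128-byte rows when it is None. An illustrative check of the resulting tilings (sketch only):

    def tiling_for(swizzle, itemsize):
        # (swizzle or 128) falls back to 128-byte rows for swizzle=None.
        return (32, (swizzle or 128) // itemsize)

    assert tiling_for(None, 2) == (32, 64)  # float16, unswizzled
    assert tiling_for(64, 4) == (32, 16)    # float32, 64-byte swizzle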
@@ -772,9 +764,10 @@ def kernel(ctx, src, dst, tmp):
dtype=(jnp.float16, jnp.float32),
)
def test_tma_squeeze_indexing(self, swizzle, dtype):
shape = (4, 5, 64)
if dtype == jnp.float32:
shape = (*shape[:-1], shape[-1] // 2)
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
# swizzle 32 and 64.
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
shape = (4, 5, minor_size)
def kernel(ctx, src, dst, tmp):
barrier = BarrierArray(1)[0]
for i in range(4):
@@ -810,13 +803,13 @@ def kernel(ctx, src, dst, tmp):
np.testing.assert_array_equal(y, x)

@parameterized.product(
swizzle=(None, 128),
shape=((64, 64), (5, 64), (2, 3, 5, 64)),
swizzle=(None, 32, 64, 128),
shape=((64, None), (5, None), (2, 3, 5, None)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_store(self, swizzle, shape, dtype):
if dtype == jnp.float32:
shape = (*shape[:-1], shape[-1] // 2)
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
shape = (*shape[:-1], minor_size)
def kernel(ctx, src, dst, tmp):
copy(src, tmp, swizzle=swizzle)
ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle)
