diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 705b5419b6c2..f1b9209274ad 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -329,6 +329,7 @@ def async_copy(
     src_ref_ty = ir.MemRefType(src_ref.type)
     dst_ref_ty = ir.MemRefType(dst_ref.type)
     element_type = src_ref_ty.element_type
+    element_bytewidth = mgpu.bytewidth(element_type)
     if element_type != dst_ref_ty.element_type:
       raise ValueError(
           f"Expected same element type, got {element_type} and"
@@ -397,10 +398,17 @@ def async_copy(
     rank = len(slice_shape)
     if rank > 5:  # TODO: apaszke - Implement stride compression
       raise ValueError("Async copies only support striding up to 5 dimensions")
+    if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth:
+      raise ValueError(
+          f"Async copies with {swizzle=} require the last dimension of the"
+          f" slice to be exactly {swizzle} bytes"
+          f" ({swizzle // element_bytewidth} elements), but got"
+          f" {slice_shape[-1]}"
+      )
     smem_ptr = utils.memref_ptr(smem_ref, memory_space=3)
     if gmem_ref is src_ref:
       assert barrier is not None  # for pytype
-      slice_bytes = c(np.prod(slice_shape) * mgpu.bytewidth(element_type), i32)
+      slice_bytes = c(np.prod(slice_shape) * element_bytewidth, i32)
       barrier_ptr = barrier.get_ptr()
       with uniform_ctx():
         if arrive:
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 48a9d0b47d2b..4b811a4326cd 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -84,32 +84,21 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
   def body(*idx):
     dst_idx = idx
     if swizzle is not None:
-      if swizzle != 128:
-        raise NotImplementedError("Only swizzle 128B implemented")
-      # TODO(apaszke): This can probably be cleaned up.
-      # But it works and it's test-only, so it doesn't matter much.
-      # After all, swizzle should just be an xor of row and linear idx,
-      # adjusted for the bytewidth.
+      assert swizzle.bit_count() == 1
       bytes_per_element = bytewidth(src_ty.element_type)
-      elems_per_tile = 1024 // bytes_per_element
-      elems_per_row = elems_per_tile // 8
-      elems_per_group = 16 // bytes_per_element
       linear_idx = c(0, index)
       for stride, i in zip(dyn_strides, idx):
         linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
-      tile_offset = arith.remui(linear_idx, c(elems_per_tile, index))
-      linear_tile_start = arith.subi(linear_idx, tile_offset)
-      row = arith.divui(tile_offset, c(elems_per_row, index))
-      row_offset = arith.remui(tile_offset, c(elems_per_row, index))
-      src_group = arith.divui(row_offset, c(elems_per_group, index))
-      group_offset = arith.remui(row_offset, c(elems_per_group, index))
-      dst_group = arith.xori(src_group, row)
-      dst_linear_idx = mlir_sum([
-          linear_tile_start,
-          arith.muli(row, c(elems_per_row, index)),
-          arith.muli(dst_group, c(elems_per_group, index)),
-          group_offset,
-      ])
+      # Swizzle pattern repeats every 128 bytes.
+      swizzle_src = arith.remui(
+          arith.divui(linear_idx, c(128 // bytes_per_element, index)),
+          c(swizzle // 16, index),
+      )
+      # Swizzle happens in groups of 16 bytes.
+      swizzle_shift = 4 - (bytes_per_element.bit_length() - 1)
+      dst_linear_idx = arith.xori(
+          linear_idx, arith.shli(swizzle_src, c(swizzle_shift, index))
+      )
       dst_idx = [
           arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index))
          for stride, bound in zip(dyn_strides, shape)
@@ -714,16 +703,17 @@ def kernel(ctx, dst, tmp):
     )()
     np.testing.assert_array_equal(y, np.full_like(y, 3, dtype=np.int32))
 
+
 class TMATest(TestCase):
 
   @parameterized.product(
-      swizzle=(None, 128),
-      shape=((64, 64), (5, 64), (2, 3, 5, 64)),
+      swizzle=(None, 32, 64, 128),
+      shape=((64, None), (5, None), (2, 3, 5, None)),
       dtype=(jnp.float16, jnp.float32),
   )
   def test_tma_load_basic(self, swizzle, shape, dtype):
-    if dtype == jnp.float32:
-      shape = (*shape[:-1], shape[-1] // 2)
+    minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
+    shape = (*shape[:-1], minor_size)
     i1 = ir.IntegerType.get_signless(1)
     def kernel(ctx, src, dst, tmp):
       barrier = BarrierArray(1)[0]
@@ -740,9 +730,11 @@ def kernel(ctx, src, dst, tmp):
       dtype=(jnp.float16, jnp.float32),
   )
   def test_tma_load_tiled(self, swizzle, shape, dtype):
+    # TODO(apaszke): ptxas seems to freeze when generating code for copy with
+    # swizzle 32 and 64.
     i1 = ir.IntegerType.get_signless(1)
     index = ir.IndexType.get()
-    tiling = (32, 128 // jnp.dtype(dtype).itemsize)
+    tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize)
     tiled_shape = tile_shape(shape, tiling)[:len(shape)]
     def kernel(ctx, src, dst, tmp):
       barrier = BarrierArray(1)[0]
@@ -772,9 +764,10 @@ def kernel(ctx, src, dst, tmp):
       dtype=(jnp.float16, jnp.float32),
   )
   def test_tma_squeeze_indexing(self, swizzle, dtype):
-    shape = (4, 5, 64)
-    if dtype == jnp.float32:
-      shape = (*shape[:-1], shape[-1] // 2)
+    # TODO(apaszke): ptxas seems to freeze when generating code for copy with
+    # swizzle 32 and 64.
+    minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
+    shape = (4, 5, minor_size)
     def kernel(ctx, src, dst, tmp):
       barrier = BarrierArray(1)[0]
       for i in range(4):
@@ -810,13 +803,13 @@ def kernel(ctx, src, dst, tmp):
     np.testing.assert_array_equal(y, x)
 
   @parameterized.product(
-      swizzle=(None, 128),
-      shape=((64, 64), (5, 64), (2, 3, 5, 64)),
+      swizzle=(None, 32, 64, 128),
+      shape=((64, None), (5, None), (2, 3, 5, None)),
       dtype=(jnp.float16, jnp.float32),
   )
   def test_tma_store(self, swizzle, shape, dtype):
-    if dtype == jnp.float32:
-      shape = (*shape[:-1], shape[-1] // 2)
+    minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
+    shape = (*shape[:-1], minor_size)
     def kernel(ctx, src, dst, tmp):
       copy(src, tmp, swizzle=swizzle)
       ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle)
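For reference, the XOR form that replaces the old tile/row/group arithmetic in `copy` can be sanity-checked with a small pure-Python model. This is a sketch under the patch's assumptions (power-of-two swizzle, element size dividing 16 bytes); the helper name `swizzle_index` and the self-check loop are illustrative and not part of the patch:

```python
# A pure-Python model of the XOR swizzle above, mirroring the MLIR
# arithmetic (divui/remui/shli/xori) on plain integers.
def swizzle_index(linear_idx: int, bytes_per_element: int, swizzle: int) -> int:
  assert swizzle & (swizzle - 1) == 0 and swizzle >= 32  # power of two
  # Swizzle pattern repeats every 128 bytes.
  swizzle_src = (linear_idx // (128 // bytes_per_element)) % (swizzle // 16)
  # Swizzle happens in groups of 16 bytes (16 // bytes_per_element elements),
  # so shift the source left by log2 of the group size in elements.
  swizzle_shift = 4 - (bytes_per_element.bit_length() - 1)
  return linear_idx ^ (swizzle_src << swizzle_shift)

# `swizzle_src` only reads index bits above the ones the XOR flips, so the
# transform is an involution: applying it twice recovers the original index.
for swizzle in (32, 64, 128):
  for bytes_per_element in (2, 4):  # jnp.float16, jnp.float32
    for idx in range(2048):
      swizzled = swizzle_index(idx, bytes_per_element, swizzle)
      assert swizzle_index(swizzled, bytes_per_element, swizzle) == idx
```

The involution property is what lets the tests reuse the same `copy` helper on both sides: `test_tma_store` swizzles data into SMEM with `copy` and expects the swizzled TMA store to undo it.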