From 701c63e19a712143868300f570e8b8a02973e3cd Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 17 Jun 2024 17:27:15 -0700 Subject: [PATCH] [Pallas/TPU] Add API for megacore partitioning of pipelines PiperOrigin-RevId: 644184524 --- jax/_src/pallas/mosaic/pipeline.py | 99 ++++++++++++++++++++++-- jax/experimental/pallas/tpu.py | 2 + tests/pallas/pallas_pipeline_tpu_test.py | 87 ++++++++++++++++++++- 3 files changed, 179 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index df501bb7d9a8..01ef76a8b631 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -14,6 +14,8 @@ """Module for emitting custom TPU pipelines within a Pallas call.""" +from __future__ import annotations + import dataclasses import enum import functools @@ -24,6 +26,7 @@ import jax from jax import lax from jax import tree_util +from jax._src import util as jax_util from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives @@ -87,15 +90,16 @@ def _grid_size(grid): return size -def _get_indices(step, grid): +def _get_indices(step, grid, offsets): """Get indices for a given step and grid.""" extended_grid = grid + (1,) strides = tuple( itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] - return tuple( + indices = tuple( lax.div(lax.rem(step, a), b) for a, b in zip(strides[:-1], strides[1:]) ) + return tuple(a + b for a, b in zip(indices, offsets, strict=True)) class BufferType(enum.Enum): @@ -350,8 +354,9 @@ class Scheduler: """Sequences input and output copies and waits for a pipeline.""" def __init__(self, - step, - grid, + step: jax.Array, + grid: tuple[int | jax.Array, ...], + grid_offsets: tuple[int | jax.Array, ...], first_cycle=None, last_cycle=None, init_accumulators=None, @@ -361,6 +366,7 @@ def __init__(self, Args: step: inner step number. grid: pallas grid for BufferedRefs. + grid_offsets: offsets for grid indices (used for megacore). first_cycle: whether this is the first invocation of the pipeline. last_cycle: whether this is the last invocation of the pipeline. init_accumulators: do we zero-initialize accumulator state for this @@ -388,9 +394,13 @@ def __init__(self, self.next_step = _mod(step + 1, self.num_steps) # Derived grid indices for present, previous, and next steps. - self.indices = _get_indices(step, grid) - self.prev_indices = _get_indices(self.prev_step, self.grid) - self.next_indices = _get_indices(self.next_step, self.grid) + self.indices = _get_indices(step, grid, grid_offsets) + self.prev_indices = _get_indices( + self.prev_step, grid, grid_offsets + ) + self.next_indices = _get_indices( + self.next_step, grid, grid_offsets + ) def grid_env(self): return pallas_core.grid_env( @@ -628,13 +638,79 @@ def make_output_bref(out_spec, out_ref, accumulate): return (*in_brefs, *out_brefs) +class GridDimensionSemantics: + pass +PARALLEL = GridDimensionSemantics() +ARBITRARY = GridDimensionSemantics() + + +def _partition_grid( + grid: tuple[int | jax.Array, ...], + core_axis: int | None, + dimension_semantics: tuple[GridDimensionSemantics, ...] | None, +) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]: + if core_axis is None: + # We aren't partitioning the grid + return grid, (0,) * len(grid) + num_cores = pl.num_programs(core_axis) + # Check that num_cores is statically known + if not isinstance(num_cores, int): + raise NotImplementedError( + f"Cannot partition grid over dynamic number of cores: {core_axis=}" + ) + if num_cores == 1: + # We aren't partitioning the grid + return grid, (0,) * len(grid) + + # If dimension_semantics aren't provided, we assume it is all arbitrary. + if dimension_semantics is None: + dimension_semantics = (ARBITRARY,) * len(grid) + if len(dimension_semantics) != len(grid): + raise ValueError("dimension_semantics must be the same length as grid.") + + parallel_dimensions = {i for i, d in enumerate(dimension_semantics) + if d == PARALLEL} + # If there are no parallel dimensions, we can't partition the grid + if not parallel_dimensions: + # TODO(sharadmv): enable running kernel on just one core + raise NotImplementedError( + "Cannot partition over cores without parallel grid dimensions:" + f" {dimension_semantics=}" + ) + + # Try to find a divisible dimension to partition the grid on + divisible_dimensions = { + i for i in parallel_dimensions + if isinstance(grid[i], int) and grid[i] % num_cores == 0 + } + if not divisible_dimensions: + # TODO(sharadmv): enable uneven grid partitioning + raise NotImplementedError( + f"Uneven partitioning of grid not supported: {grid=}, {num_cores=}" + ) + first_divisible_dimension, *_ = [ + i for i in range(len(dimension_semantics)) if i in divisible_dimensions + ] + partitioned_dim_size = grid[first_divisible_dimension] // num_cores + partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size + new_grid = jax_util.tuple_update( + grid, first_divisible_dimension, partitioned_dim_size + ) + offsets = jax_util.tuple_update( + (0,) * len(grid), first_divisible_dimension, partitioned_dim_offset + ) + return new_grid, offsets + + def emit_pipeline( body, *, - grid, + grid: tuple[int | jax.Array, ...], in_specs=None, out_specs=None, should_accumulate_out=False, + core_axis: int | None = None, + dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None ): """Creates a function to emit a manual pallas pipeline. @@ -653,7 +729,13 @@ def emit_pipeline( out_specs: output pallas block specs should_accumulate_out: booleans to indicate which outputs should be treated as accumulators. + core_axis: optional int, indicates whether or not to partition the grid + along the core axis. + dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL + or ARBITRARY). """ + grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics) + num_steps = _grid_size(grid) if not isinstance(in_specs, (list, tuple)): in_specs = (in_specs,) @@ -737,6 +819,7 @@ def loop_body(step, _): scheduler = Scheduler( step, grid, + grid_offsets=grid_offsets, first_cycle=first_cycle, last_cycle=last_cycle, init_accumulators=init_accumulators) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index c0705f67fcfc..ad5fb92719d0 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -26,6 +26,8 @@ from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations +from jax._src.pallas.mosaic.pipeline import ARBITRARY +from jax._src.pallas.mosaic.pipeline import PARALLEL from jax._src.pallas.mosaic.primitives import async_copy from jax._src.pallas.mosaic.primitives import async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast diff --git a/tests/pallas/pallas_pipeline_tpu_test.py b/tests/pallas/pallas_pipeline_tpu_test.py index 030afbcc9ed8..75be164c655c 100644 --- a/tests/pallas/pallas_pipeline_tpu_test.py +++ b/tests/pallas/pallas_pipeline_tpu_test.py @@ -174,7 +174,7 @@ def emit_pipeline(should_accumulate_out): np.testing.assert_allclose(z, jnp.dot(x, y) + jnp.dot(x, y)) -class PallasCallColectivePipelineTest(parameterized.TestCase): +class PallasCallCollectivePipelineTest(parameterized.TestCase): def setUp(self): if jax.device_count() < 2: @@ -1263,5 +1263,90 @@ def reference(x, y): ) +class PallasCallMegacoreTest(parameterized.TestCase): + + def setUp(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works with TPU v4') + + super().setUp() + + def test_megacore_mul(self): + x = jax.random.uniform(jax.random.key(0), (512, 512)) + + def matmul_pipeline(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + def matmul_kernel(x_ref, y_ref): + pltpu.emit_pipeline( + matmul_pipeline, + grid=(4, 4), + in_specs=[ + pl.BlockSpec(lambda i, j: (i, j), (128, 128)), + ], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (128, 128)), + core_axis=0, + dimension_semantics=(pltpu.ARBITRARY, pltpu.PARALLEL) + )(x_ref, y_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + np.testing.assert_allclose(func(x), x * 2) + + @parameterized.parameters( + (1024, 1024, 1024, 256, 512, 256), + (768, 1024, 1024, 256, 512, 256), + (1024, 1024, 768, 256, 512, 256), + ) + def test_megacore_matmul(self, m, k, n, bm, bk, bn): + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform(k1, (m, k)) + y = jax.random.uniform(k2, (k, n)) + + def matmul_pipeline(x_ref, y_ref, z_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + z_ref[...] = jnp.zeros_like(z_ref) + z_ref[...] += x_ref[...] @ y_ref[...] + + def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): + m, k = x_ref.shape + _, n = y_ref.shape + pltpu.emit_pipeline( + matmul_pipeline, + grid=(m // bm, n // bn, k // bk), + in_specs=[ + pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)), + pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)), + ], + out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)), + core_axis=0, + dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL, pltpu.ARBITRARY) + )(x_ref, y_ref, z_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + functools.partial(matmul_kernel, bm=bm, bk=bk, bn=bn), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())