From 24ad445e99eb1facd799260ad3cdb1f1b63a9c95 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 3 Oct 2023 11:51:48 -0700 Subject: [PATCH] [Pallas] Add support for pytrees in scalar prefetch PiperOrigin-RevId: 570453699 --- jax/_src/pallas/core.py | 19 +++++++++++-------- jax/_src/pallas/mosaic/core.py | 11 ++++++++--- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 6b822f5371fa..1825f4f6e05f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -21,6 +21,7 @@ import functools from typing import Any, Callable, Iterator +from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu from jax._src import state @@ -136,20 +137,20 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid: def _convert_block_spec_to_block_mapping( in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None, - aval: jax_core.ShapedArray, + aval: jax_core.ShapedArray, in_tree: Any, ) -> BlockSpec | None: if block_spec is no_block_spec: return None if block_spec.index_map is None: - compute_index = lambda *args: (0,) * len(aval.shape) + compute_index = lambda *args, **kwargs: (0,) * len(aval.shape) block_shape = aval.shape else: compute_index = block_spec.compute_index block_shape = block_spec.block_shape block_shape = tuple( mapped if s is None else s for s in block_shape) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(compute_index), in_avals) + flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.memory_space) @@ -249,12 +250,14 @@ def get_grid_mapping( self.grid, in_avals, flat_in_specs, out_avals, flat_out_specs) grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid) + # Create args, kwargs pytree def + grid_tree = tree_util.tree_structure((tuple(grid_avals), {})) in_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs, - in_ref_avals) + partial(_convert_block_spec_to_block_mapping, grid_avals, + in_tree=grid_tree), in_specs, in_ref_avals) out_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs, - out_ref_avals) + partial(_convert_block_spec_to_block_mapping, grid_avals, + in_tree=grid_tree), out_specs, out_ref_avals) grid_mapping = GridMapping( self.grid, (*in_block_mappings, *out_block_mappings), (), num_index_operands=0) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index aef0800bc66a..a0822117baf4 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -126,7 +126,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): in_specs_tree: Any out_specs_tree: Any - def __init__( self, num_scalar_prefetch: int, @@ -160,12 +159,18 @@ def get_grid_mapping( state.shaped_array_ref(aval.shape, aval.dtype) for aval in flat_scalar_avals] grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid) + # Create args, kwargs pytree def + index_map_in_tree = tree_util.tree_structure( + ((*grid_avals, *scalar_avals), {}) + ) in_block_mappings = map( partial(_convert_block_spec_to_block_mapping, - (*grid_avals, *scalar_ref_avals)), in_specs, in_ref_avals) + (*grid_avals, *scalar_ref_avals), + in_tree=index_map_in_tree), in_specs, in_ref_avals) out_block_mappings = map( partial(_convert_block_spec_to_block_mapping, - (*grid_avals, *scalar_ref_avals)), out_specs, out_ref_avals) + (*grid_avals, *scalar_ref_avals), + in_tree=index_map_in_tree), out_specs, out_ref_avals) grid_mapping = GridMapping( grid=self.grid, block_mappings=(*in_block_mappings, *out_block_mappings),