From 2a1300d2f913fd9c6bf8621cc866a72c03ec3d56 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 5 Apr 2024 20:08:48 -0700 Subject: [PATCH] Add `Layout` support to `jax.jit`. `jax.jit` now accepts `Layout` instances in the `in_shardings` and `out_shardings` arguments. The major changes are just plumbing `in_layouts` and `out_layouts` everywhere. Note that the public API is `Layout(device_local_layout, sharding)`, which is how users will pass us the Layout, but internally we split it apart into device_local_layout and sharding. Docs are coming on how to use the API, what Layouts mean, and how to make sense of them (especially on TPU). PiperOrigin-RevId: 622352537 --- oryx/core/interpreters/harvest.py | 35 ++++++++++++++++++----------- oryx/core/interpreters/propagate.py | 4 ++++ oryx/core/ppl/effect_handler.py | 4 ++++ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index cf1bb46..4d4c95b 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -1041,7 +1041,8 @@ def _calc_extra_inps(num_consts, params): in_shardings = ( sharding_impls.UNSPECIFIED,) * num_consts + params['in_shardings'] donated_invars = (False,) * num_consts + params['donated_invars'] - return in_shardings, donated_invars + in_layouts = (None,) * num_consts + params['in_layouts'] + return in_shardings, donated_invars, in_layouts def _reap_pjit_rule(trace, *tracers, **params): @@ -1078,14 +1079,18 @@ def _reap_pjit_rule(trace, *tracers, **params): reap_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr( flat_fun, tuple(t.aval for t in tracers)) - in_shardings, donated_invars = _calc_extra_inps( + in_shardings, donated_invars, in_layouts = _calc_extra_inps( len(final_consts), params) - new_params = {**params, - 'jaxpr': reap_jaxpr, - 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), - 'in_shardings': in_shardings, - 'donated_invars': donated_invars} + new_params = {
**params, + 'jaxpr': reap_jaxpr, + 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), + 'in_shardings': in_shardings, + 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': (None,) * len(out_avals) + } outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) outvals = jax_util.safe_map(trace.pure, outvals) @@ -1500,14 +1505,18 @@ def _plant_pjit_rule(trace, *tracers, **params): planted_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr( flat_fun, tuple(t.aval for t in tracers)) - in_shardings, donated_invars = _calc_extra_inps( + in_shardings, donated_invars, in_layouts = _calc_extra_inps( len(final_consts), params) - new_params = {**params, - 'jaxpr': planted_jaxpr, - 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), - 'in_shardings': in_shardings, - 'donated_invars': donated_invars} + new_params = { + **params, + 'jaxpr': planted_jaxpr, + 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), + 'in_shardings': in_shardings, + 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': (None,) * len(out_avals), + } outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) return jax_util.safe_map(trace.pure, outvals) diff --git a/oryx/core/interpreters/propagate.py b/oryx/core/interpreters/propagate.py index 1a91075..bcf5b3b 100644 --- a/oryx/core/interpreters/propagate.py +++ b/oryx/core/interpreters/propagate.py @@ -384,6 +384,8 @@ def _pjit_propagate_rule(incells, outcells, **params): in_shardings = (sharding_impls.UNSPECIFIED,) * len(flat_vals) donated_invars = (False,) * len(flat_vals) out_shardings = (sharding_impls.UNSPECIFIED,) * len(new_jaxpr.out_avals) + in_layouts = (None,) * len(flat_vals) + out_layouts = (None,) * len(new_jaxpr.out_avals) new_params = { **params, @@ -391,6 +393,8 @@ def _pjit_propagate_rule(incells, outcells, **params): 'in_shardings': in_shardings, 'out_shardings': out_shardings, 'donated_invars': donated_invars, + 'in_layouts': 
in_layouts, + 'out_layouts': out_layouts, } flat_out = pjit.pjit_p.bind(*flat_vals, **new_params) return tree_util.tree_unflatten(out_tree(), flat_out) diff --git a/oryx/core/ppl/effect_handler.py b/oryx/core/ppl/effect_handler.py index 905d5d4..79eefcc 100644 --- a/oryx/core/ppl/effect_handler.py +++ b/oryx/core/ppl/effect_handler.py @@ -276,6 +276,8 @@ def _pjit_effect_handler_rule(rules, state, invals, **params): out_shardings = (sharding_impls.UNSPECIFIED,) * num_state + params[ 'out_shardings' ] + in_layouts = (None,) * num_state + params['in_layouts'] + out_layouts = (None,) * num_state + params['out_layouts'] new_params = { **params, @@ -283,6 +285,8 @@ def _pjit_effect_handler_rule(rules, state, invals, **params): 'in_shardings': in_shardings, 'out_shardings': out_shardings, 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': out_layouts, } ans_state = pjit.pjit_p.bind(*state_invals, **new_params)