diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index cf1bb46..4d4c95b 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -1041,7 +1041,8 @@ def _calc_extra_inps(num_consts, params): in_shardings = ( sharding_impls.UNSPECIFIED,) * num_consts + params['in_shardings'] donated_invars = (False,) * num_consts + params['donated_invars'] - return in_shardings, donated_invars + in_layouts = (None,) * num_consts + params['in_layouts'] + return in_shardings, donated_invars, in_layouts def _reap_pjit_rule(trace, *tracers, **params): @@ -1078,14 +1079,18 @@ def _reap_pjit_rule(trace, *tracers, **params): reap_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr( flat_fun, tuple(t.aval for t in tracers)) - in_shardings, donated_invars = _calc_extra_inps( + in_shardings, donated_invars, in_layouts = _calc_extra_inps( len(final_consts), params) - new_params = {**params, - 'jaxpr': reap_jaxpr, - 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), - 'in_shardings': in_shardings, - 'donated_invars': donated_invars} + new_params = { + **params, + 'jaxpr': reap_jaxpr, + 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), + 'in_shardings': in_shardings, + 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': (None,) * len(out_avals) + } outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) outvals = jax_util.safe_map(trace.pure, outvals) @@ -1500,14 +1505,18 @@ def _plant_pjit_rule(trace, *tracers, **params): planted_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr( flat_fun, tuple(t.aval for t in tracers)) - in_shardings, donated_invars = _calc_extra_inps( + in_shardings, donated_invars, in_layouts = _calc_extra_inps( len(final_consts), params) - new_params = {**params, - 'jaxpr': planted_jaxpr, - 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), - 'in_shardings': in_shardings, - 'donated_invars': donated_invars} + new_params = { + **params, + 'jaxpr': planted_jaxpr, + 'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals), + 'in_shardings': in_shardings, + 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': (None,) * len(out_avals), + } outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) return jax_util.safe_map(trace.pure, outvals) diff --git a/oryx/core/interpreters/propagate.py b/oryx/core/interpreters/propagate.py index 1a91075..bcf5b3b 100644 --- a/oryx/core/interpreters/propagate.py +++ b/oryx/core/interpreters/propagate.py @@ -384,6 +384,8 @@ def _pjit_propagate_rule(incells, outcells, **params): in_shardings = (sharding_impls.UNSPECIFIED,) * len(flat_vals) donated_invars = (False,) * len(flat_vals) out_shardings = (sharding_impls.UNSPECIFIED,) * len(new_jaxpr.out_avals) + in_layouts = (None,) * len(flat_vals) + out_layouts = (None,) * len(new_jaxpr.out_avals) new_params = { **params, @@ -391,6 +393,8 @@ def _pjit_propagate_rule(incells, outcells, **params): 'in_shardings': in_shardings, 'out_shardings': out_shardings, 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': out_layouts, } flat_out = pjit.pjit_p.bind(*flat_vals, **new_params) return tree_util.tree_unflatten(out_tree(), flat_out) diff --git a/oryx/core/ppl/effect_handler.py b/oryx/core/ppl/effect_handler.py index 905d5d4..79eefcc 100644 --- a/oryx/core/ppl/effect_handler.py +++ b/oryx/core/ppl/effect_handler.py @@ -276,6 +276,8 @@ def _pjit_effect_handler_rule(rules, state, invals, **params): out_shardings = (sharding_impls.UNSPECIFIED,) * num_state + params[ 'out_shardings' ] + in_layouts = (None,) * num_state + params['in_layouts'] + out_layouts = (None,) * num_state + params['out_layouts'] new_params = { **params, @@ -283,6 +285,8 @@ def _pjit_effect_handler_rule(rules, state, invals, **params): 'in_shardings': in_shardings, 'out_shardings': out_shardings, 'donated_invars': donated_invars, + 'in_layouts': in_layouts, + 'out_layouts': out_layouts, } ans_state = pjit.pjit_p.bind(*state_invals, **new_params)