Skip to content

Commit

Permalink
Add Layout support to jax.jit.
Browse files Browse the repository at this point in the history
`jax.jit` now accepts `Layout` instances to the `in_shardings` and `out_shardings` argument. Major changes are just plumbing `in_layouts` and `out_layouts` everywhere.

Note that public api is `Layout(device_local_layout, sharding)` which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.

Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 621747200
  • Loading branch information
yashk2810 authored and The oryx Authors committed Apr 6, 2024
1 parent b59ab02 commit 08d40f5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
35 changes: 22 additions & 13 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,8 @@ def _calc_extra_inps(num_consts, params):
in_shardings = (
sharding_impls.UNSPECIFIED,) * num_consts + params['in_shardings']
donated_invars = (False,) * num_consts + params['donated_invars']
return in_shardings, donated_invars
in_layouts = (None,) * num_consts + params['in_layouts']
return in_shardings, donated_invars, in_layouts


def _reap_pjit_rule(trace, *tracers, **params):
Expand Down Expand Up @@ -1078,14 +1079,18 @@ def _reap_pjit_rule(trace, *tracers, **params):

reap_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr(
flat_fun, tuple(t.aval for t in tracers))
in_shardings, donated_invars = _calc_extra_inps(
in_shardings, donated_invars, in_layouts = _calc_extra_inps(
len(final_consts), params)

new_params = {**params,
'jaxpr': reap_jaxpr,
'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals),
'in_shardings': in_shardings,
'donated_invars': donated_invars}
new_params = {
**params,
'jaxpr': reap_jaxpr,
'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals),
'in_shardings': in_shardings,
'donated_invars': donated_invars,
'in_layouts': in_layouts,
'out_layouts': (None,) * len(out_avals)
}
outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params)

outvals = jax_util.safe_map(trace.pure, outvals)
Expand Down Expand Up @@ -1500,14 +1505,18 @@ def _plant_pjit_rule(trace, *tracers, **params):

planted_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr(
flat_fun, tuple(t.aval for t in tracers))
in_shardings, donated_invars = _calc_extra_inps(
in_shardings, donated_invars, in_layouts = _calc_extra_inps(
len(final_consts), params)

new_params = {**params,
'jaxpr': planted_jaxpr,
'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals),
'in_shardings': in_shardings,
'donated_invars': donated_invars}
new_params = {
**params,
'jaxpr': planted_jaxpr,
'out_shardings': (sharding_impls.UNSPECIFIED,) * len(out_avals),
'in_shardings': in_shardings,
'donated_invars': donated_invars,
'in_layouts': in_layouts,
'out_layouts': (None,) * len(out_avals),
}
outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params)

return jax_util.safe_map(trace.pure, outvals)
Expand Down
4 changes: 4 additions & 0 deletions oryx/core/interpreters/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,17 @@ def _pjit_propagate_rule(incells, outcells, **params):
in_shardings = (sharding_impls.UNSPECIFIED,) * len(flat_vals)
donated_invars = (False,) * len(flat_vals)
out_shardings = (sharding_impls.UNSPECIFIED,) * len(new_jaxpr.out_avals)
in_layouts = (None,) * len(flat_vals)
out_layouts = (None,) * len(new_jaxpr.out_avals)

new_params = {
**params,
'jaxpr': new_jaxpr,
'in_shardings': in_shardings,
'out_shardings': out_shardings,
'donated_invars': donated_invars,
'in_layouts': in_layouts,
'out_layouts': out_layouts,
}
flat_out = pjit.pjit_p.bind(*flat_vals, **new_params)
return tree_util.tree_unflatten(out_tree(), flat_out)
4 changes: 4 additions & 0 deletions oryx/core/ppl/effect_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,17 @@ def _pjit_effect_handler_rule(rules, state, invals, **params):
out_shardings = (sharding_impls.UNSPECIFIED,) * num_state + params[
'out_shardings'
]
in_layouts = (None,) * num_state + params['in_layouts']
out_layouts = (None,) * num_state + params['out_layouts']

new_params = {
**params,
'jaxpr': new_jaxpr,
'in_shardings': in_shardings,
'out_shardings': out_shardings,
'donated_invars': donated_invars,
'in_layouts': in_layouts,
'out_layouts': out_layouts,
}

ans_state = pjit.pjit_p.bind(*state_invals, **new_params)
Expand Down

0 comments on commit 08d40f5

Please sign in to comment.