pjit cleanup: inline checks for unspecified value
jakevdp committed Oct 22, 2024
1 parent e4f3f8f commit be7ff61
Showing 1 changed file with 29 additions and 31 deletions.
60 changes: 29 additions & 31 deletions jax/_src/pjit.py
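For context: the is_unspecified and is_unspecified_or_auto helpers dropped from this import list are thin wrappers over isinstance checks, which is what makes the mechanical substitutions below behavior-preserving. A minimal sketch of the relevant definitions, paraphrased from jax/_src/sharding_impls.py at the parent commit (treat the exact bodies as an assumption inferred from the substitutions in this diff):

    class UnspecifiedValue:
      def __repr__(self):
        return 'UnspecifiedValue'

    UNSPECIFIED = UnspecifiedValue()   # singleton sentinel used throughout pjit

    class AUTO:                        # sentinel type: defer sharding choice to the compiler
      ...

    def is_unspecified(x) -> bool:
      return isinstance(x, UnspecifiedValue)

    def is_unspecified_or_auto(x) -> bool:
      return isinstance(x, (AUTO, UnspecifiedValue))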
@@ -67,8 +67,7 @@
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
-    ParsedPartitionSpec, get_single_pspec, is_unspecified,
-    is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
+    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
 from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
 from jax._src.traceback_util import api_boundary
@@ -418,10 +417,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
   if device is not None and backend is not None:
     raise ValueError("can't specify both a device and a backend for jit, "
                      f"got {device=} and {backend=}")
-  if in_shardings is not None and not is_unspecified(in_shardings):
+  if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue):
     raise ValueError('If backend or device is specified on jit, then '
                      'in_shardings should not be specified.')
-  if out_shardings is not None and not is_unspecified(out_shardings):
+  if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue):
     raise ValueError('If backend or device is specified on jit, then '
                      'out_shardings should not be specified.')

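The validation semantics here are unchanged by the cleanup: combining device or backend with explicit shardings still raises eagerly at jit time. A hedged usage sketch (assumes a CPU backend is available; jit's backend parameter is deprecated but was still validated at this point):

    import jax

    try:
      f = jax.jit(lambda x: x + 1, backend='cpu',
                  in_shardings=jax.sharding.PartitionSpec('x'))  # illustrative
    except ValueError as e:
      print(e)  # If backend or device is specified on jit, then
                # in_shardings should not be specified.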
@@ -440,7 +439,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
   out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')

   user_specified_in_shardings = (in_shardings is not None and
-                                 not is_unspecified(in_shardings))
+                                 not isinstance(in_shardings, UnspecifiedValue))

   in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
   out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
@@ -483,7 +482,7 @@ def lower(*args, **kwargs):
   @api_boundary
   def eval_shape(*args, **kwargs):
     p, _ = _infer_params(fun, jit_info, args, kwargs)
-    out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
+    out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]
     # TODO(yashkatariya): Add `Layout` to SDS.
     out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
                                 weak_type=x.weak_type)
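As the hunk above shows, unspecified out_shardings surface as sharding=None on the ShapeDtypeStruct results of eval_shape. A small sketch (illustrative; single-device behavior assumed):

    import jax
    import jax.numpy as jnp

    out = jax.jit(lambda x: x * 2).eval_shape(jnp.ones((4, 8)))
    print(out.shape, out.dtype, out.sharding)  # (4, 8) float32 None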
@@ -1001,7 +1000,7 @@ def hashable_pytree(pytree):
 def _create_sharding_for_array(mesh, x, name, api_name):
   if x is None and (mesh is None or mesh.empty):
     return UNSPECIFIED
-  if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x):
+  if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)):
     return x
   if mesh is None:
     msg = ('jax.jit only supports `Sharding`s being passed to'
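The rewritten condition folds the Sharding check and the old is_unspecified_or_auto helper into a single isinstance call; passing a tuple of types is an OR of the individual checks:

    # isinstance with a tuple of types is equivalent to a disjunction:
    assert isinstance(3, (str, int)) == (isinstance(3, str) or isinstance(3, int))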
@@ -1110,7 +1109,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
   orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
   # Only do this if original in_shardings are unspecified. If it is AUTO, go
   # via flatten_axis_resources.
-  if is_unspecified(orig_in_shardings):
+  if isinstance(orig_in_shardings, UnspecifiedValue):
     in_shardings_flat = (orig_in_shardings,) * len(in_avals)
   else:
     in_shardings_flat = flatten_axis_resources(
@@ -1312,7 +1311,7 @@ def _check_and_canonicalize_out_shardings(
     out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
     out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set):
   orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
-  if (is_unspecified(orig_out_shardings) or
+  if (isinstance(orig_out_shardings, UnspecifiedValue) or
       isinstance(orig_out_shardings, sharding.Sharding)):
     out_shardings_flat = (orig_out_shardings,) * len(out_avals)
   else:
@@ -1391,7 +1390,7 @@ def pjit_check_aval_sharding(
     what_aval: str, allow_uneven_sharding: bool):
   new_names = [''] * len(shardings) if names is None else names
   for aval, s, name in zip(flat_avals, shardings, new_names):
-    if is_unspecified_or_auto(s):
+    if isinstance(s, (UnspecifiedValue, AUTO)):
       continue
     name_str = f' with pytree key path {name}' if name else ''
     shape = aval.shape
@@ -1466,7 +1465,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
     else:
       arg_layout, dispatch_arg_layout = None, None
     # Sharding can be unspecified when array is committed if it's a PmapSharding.
-    is_pmap_sharding = (is_unspecified(rs) or
+    is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or
                         isinstance(getattr(arg, 'sharding', None), PmapSharding))
     if jit_in_l is None:
       if committed:
@@ -1527,15 +1526,15 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
       if getattr(a, '_committed', True):
         committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))

-  resolved_in_shardings = []
+  resolved_in_shardings: list[PjitSharding] = []
   for arg, pjit_in_s in zip(args, pjit_in_shardings):
     # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
     # not allow None as the sharding.
     arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
                         if hasattr(arg, 'sharding') and arg.sharding is not None
                         else (UNSPECIFIED, False))
-    if is_unspecified(pjit_in_s):
-      if is_unspecified(arg_s):
+    if isinstance(pjit_in_s, UnspecifiedValue):
+      if isinstance(arg_s, UnspecifiedValue):
         resolved_in_shardings.append(arg_s)
       else:
         if committed:
@@ -1553,7 +1552,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
               'multiple devices is not supported.')
       else:
         if (isinstance(arg, np.ndarray) and
-            not pjit_in_s.is_fully_replicated and  # type: ignore
+            not pjit_in_s.is_fully_replicated and  # type: ignore[union-attr]
             xb.process_count() > 1):
           raise ValueError(
               'Passing non-trivial shardings for numpy '
@@ -1572,16 +1571,16 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
         # jax.jit does not allow resharding across different memory kinds even
         # if the argument is uncommitted. Use jax.device_put for those cases,
         # either outside or inside jax.jit.
-        if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore
+        if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore[union-attr]
           raise ValueError(
               'Memory kinds passed to jax.jit does not match memory kind on the'
-              f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore
+              f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore[union-attr]
               f'arg memory kind: {arg_s.memory_kind} for '
               f'arg shape: {shaped_abstractify(arg).str_short()}')
         if (committed and
             not isinstance(arg_s, PmapSharding) and
             not op_shardings.are_op_shardings_equal(
-                pjit_in_s._to_xla_hlo_sharding(arg.ndim),  # type: ignore
+                pjit_in_s._to_xla_hlo_sharding(arg.ndim),  # type: ignore[union-attr]
                 arg_s._to_xla_hlo_sharding(arg.ndim))):
           raise ValueError('Sharding passed to pjit does not match the sharding '
                            'on the respective arg. '
@@ -1780,8 +1779,8 @@ def pjit_staging_rule(trace, *args, **params):
     params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
                   out_layouts=out_layouts)
   if (params["inline"] and
-      all(is_unspecified(i) for i in params["in_shardings"]) and
-      all(is_unspecified(o) for o in params["out_shardings"]) and
+      all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and
+      all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and
       all(i is None for i in params["in_layouts"]) and
       all(o is None for o in params["out_layouts"])):
     if config.dynamic_shapes.value:
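For reference, the fast path guarded here is what lets jit(..., inline=True) with no sharding or layout constraints dissolve into the surrounding trace instead of staging a nested pjit equation. An illustrative sketch:

    import jax
    import jax.numpy as jnp

    g = jax.jit(lambda x: x + 1, inline=True)

    @jax.jit
    def h(x):
      return g(x) * 2

    # Since g specifies no shardings or layouts, its jaxpr should appear
    # inlined in h's jaxpr rather than as a nested pjit equation.
    print(jax.make_jaxpr(h)(jnp.ones(3)))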
@@ -1830,7 +1829,7 @@ def pjit_staging_rule(trace, *args, **params):

 def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
   in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)
-  in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol
+  in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol
             in zip(in_fwd, out_shardings, out_layouts)]
   keep = [f is None for f in in_fwd]
   jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
@@ -1896,8 +1895,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,

   func = mod_ctx.cached_primitive_lowerings.get(key, None)
   if func is None:
-    arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
-    result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
+    arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings]
+    result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings]
     # TODO(b/228598865): inlined calls cannot have shardings set directly on the
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.
@@ -1990,9 +1989,9 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
 def _pjit_batcher_for_sharding(
     s: sharding.Sharding | UnspecifiedValue,
     dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
-  if is_unspecified(s):
+  if isinstance(s, UnspecifiedValue):
     return s
-  hlo_s = s._to_xla_hlo_sharding(ndim)  # type: ignore
+  hlo_s = s._to_xla_hlo_sharding(ndim)
   if spmd_axis_name is None:
     if sharding_impls.is_op_sharding_replicated(hlo_s):
       return s
@@ -2004,7 +2003,7 @@ def _pjit_batcher_for_sharding(
       tad.insert(dim, 1)
       new_op.tile_assignment_dimensions = tad
       new_gs = GSPMDSharding(
-          s._device_assignment, new_op,  # type: ignore
+          s._device_assignment, new_op,
           _device_list=getattr(s, '_internal_device_list', None))
       return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]
   else:
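A side benefit visible in the two hunks above: once the guard is a literal isinstance, mypy narrows s from sharding.Sharding | UnspecifiedValue to sharding.Sharding after the early return, so the trailing # type: ignore comments can simply be dropped. A minimal standalone sketch of the pattern (hypothetical names, not JAX APIs):

    class Unspecified: ...

    class Sharding:
      def to_hlo(self) -> str:
        return 'hlo'

    def lower(s: Sharding | Unspecified) -> str:
      if isinstance(s, Unspecified):  # mypy narrows s to Sharding below
        return 'unspecified'
      return s.to_hlo()  # no "# type: ignore" needed here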
@@ -2107,7 +2106,7 @@ def keep_where(l, should_keep):
   # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
   in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
   in_fwd = [
-      fwd if is_unspecified(os) and ol is None else None
+      fwd if isinstance(os, UnspecifiedValue) and ol is None else None
       for os, ol, fwd in zip(
          keep_where(out_shardings, known_outs),
          keep_where(out_layouts, known_outs), in_fwd_primal)
@@ -2358,9 +2357,9 @@ def _pjit_pp_rule(eqn, context, settings):
     del params['inline']
   if not any(params['donated_invars']):
     del params['donated_invars']
-  if all(is_unspecified(s) for s in params['in_shardings']):
+  if all(isinstance(s, UnspecifiedValue) for s in params['in_shardings']):
     del params['in_shardings']
-  if all(is_unspecified(s) for s in params['out_shardings']):
+  if all(isinstance(s, UnspecifiedValue) for s in params['out_shardings']):
     del params['out_shardings']
   if all(l is None for l in params['in_layouts']):
     del params['in_layouts']
@@ -2382,8 +2381,7 @@ def _pjit_pp_rule(eqn, context, settings):
 def _pjit_state_discharge_rule(
     in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings,
     in_layouts, out_layouts, **params):
-  if not (all(map(is_unspecified, in_shardings)) and
-          all(map(is_unspecified, out_shardings))):
+  if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)):
     raise NotImplementedError

   if not (all(l is None for l in in_layouts) and
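The rewritten guard is a direct equivalence: one pass over the concatenation of both tuples matches the conjunction of the two all(map(...)) calls it replaces. An illustrative check, reusing the sketch definitions from the top of this page:

    ins = (UNSPECIFIED, UNSPECIFIED)
    outs = (UNSPECIFIED,)
    old = all(map(is_unspecified, ins)) and all(map(is_unspecified, outs))
    new = all(isinstance(s, UnspecifiedValue) for s in (*ins, *outs))
    assert old == new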