Skip to content

Commit

Permalink
sharding cleanup: use inline checks for unimplemented and auto
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 23, 2024
1 parent df6e5e7 commit ce13a14
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 82 deletions.
4 changes: 2 additions & 2 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,9 +801,9 @@ def _export_lowered(
nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> HloSharding | None:
if sharding_impls.is_unspecified(s):
if isinstance(s, sharding_impls.UnspecifiedValue):
return None
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]
return s._to_xla_hlo_sharding(aval.ndim)

all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
module_kept_var_idx,
Expand Down
55 changes: 28 additions & 27 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
UnspecifiedValue, get_array_mapping as _get_array_mapping,
array_mapping_to_axis_resources,
SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
tuple_update, tuple_delete, distributed_debug_log,
Expand Down Expand Up @@ -149,7 +149,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args,

@lru_cache(maxsize=2048)
def is_default_layout(curr_layout, sharding, aval):
if curr_layout is None or sharding is None or is_unspecified(sharding):
if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue):
return True
if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
dtypes.issubdtype(aval.dtype, dtypes.extended)):
Expand Down Expand Up @@ -1643,7 +1643,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
def check_if_any_auto(
shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
for s in shardings:
if is_auto(s):
if isinstance(s, AUTO):
return True
return False

Expand Down Expand Up @@ -1727,14 +1727,14 @@ def _get_and_check_device_assignment(
devices = tuple(devices)

for i, s_type, source_info in shardings:
if is_unspecified(i):
if isinstance(i, UnspecifiedValue):
continue

if first_sharding_info is None:
first_sharding_info = (
(i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore
else (i._device_assignment, s_type, source_info)) # type: ignore
arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore
(i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
else (i._device_assignment, s_type, source_info))
arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
Expand Down Expand Up @@ -1836,7 +1836,8 @@ class SemanticallyEqualShardings:
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
avals: tuple[core.AbstractValue]):
gspmd_shardings = [
s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore
s if isinstance(s, (UnspecifiedValue, AUTO))
else to_gspmd_sharding(s, a.ndim)
for s, a in zip(shardings, avals)]
self._gspmd_shardings = gspmd_shardings
self.shardings = shardings
Expand Down Expand Up @@ -2004,7 +2005,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
except:
return True
for i in shardings:
if is_unspecified_or_auto(i):
if isinstance(i, (UnspecifiedValue, AUTO)):
continue
if i.memory_kind is None: # pytype: disable=attribute-error
continue
Expand Down Expand Up @@ -2034,7 +2035,7 @@ def _default_rule(prim, num_outvars, *_, **__):
if in_shardings is None:
invar_mem_kind = [None] * len(jaxpr.invars)
else:
invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind
invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind
for s in in_shardings]
safe_map(write, jaxpr.invars, invar_mem_kind)
safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))
Expand Down Expand Up @@ -2129,7 +2130,7 @@ def _abstract_to_concrete_mesh(abstract_mesh):

out = []
for s, a in zip(shardings, avals):
if is_unspecified(s) and a.sharding is not None:
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
Expand Down Expand Up @@ -2216,9 +2217,9 @@ def lower_sharding_computation(
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not is_unspecified(i) for i in unique_in_shardings) or
any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
any(not is_unspecified(o) for o in unique_out_shardings))
any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))

da_object = _create_da_object(tuple(device_assignment))

Expand Down Expand Up @@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings(
new_in_shardings = []
for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
global_in_avals):
if is_unspecified(orig):
if isinstance(orig, UnspecifiedValue):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
Expand Down Expand Up @@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings(
new_out_shardings = []
for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
global_out_avals):
if is_unspecified(orig):
if isinstance(orig, UnspecifiedValue):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
Expand Down Expand Up @@ -2839,16 +2840,16 @@ def from_hlo(name: str,
da = _create_da_object(tuple(device_assignment))
del device_assignment

allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i)
allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
for i in in_shardings)
allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o)
allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
for o in out_shardings)

mesh = None
if auto_spmd_lowering:
for i in it.chain.from_iterable([in_shardings, out_shardings]):
if is_auto(i):
mesh = i.mesh # type: ignore
if isinstance(i, AUTO):
mesh = i.mesh
break

xla_executable = _cached_compilation(
Expand All @@ -2861,9 +2862,9 @@ def from_hlo(name: str,
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else i
in_shardings = [x if isinstance(i, AUTO) else i
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings = [x if is_auto(o) else o
out_shardings = [x if isinstance(o, AUTO) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]
else:
if pmap_nreps == 1:
Expand Down Expand Up @@ -2954,8 +2955,8 @@ def contains_explicit_attributes(self):
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
any(not is_unspecified(i) for i in self.in_shardings_leaves) or
any(not is_unspecified(o) for o in self.out_shardings_leaves) or
any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or
any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))

Expand Down Expand Up @@ -3130,7 +3131,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):

def check_device_backend_on_shardings(shardings) -> bool:
for i in shardings:
if is_unspecified(i) or is_auto(i):
if isinstance(i, (UnspecifiedValue, AUTO)):
continue
if getattr(i, '_device_backend', False):
return True
Expand All @@ -3156,7 +3157,7 @@ def check_array_xla_sharding_layout_match(
args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
if not isinstance(arg, ArrayImpl):
continue
if is_unspecified_or_auto(xs):
if isinstance(xs, (UnspecifiedValue, AUTO)):
continue

db_xs = check_device_backend_on_shardings([xs])
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.sharding_impls import AUTO as AutoSharding
from jax._src.lib import xla_client as xc

Shape = tuple[int, ...]
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(self, device_local_layout: LayoutOptions = None,
sharding: ShardingOptions = None):
# If layout is concrete and sharding is not, error.
if (isinstance(device_local_layout, DeviceLocalLayout) and
(sharding is None or is_auto(sharding))):
(sharding is None or isinstance(sharding, AutoSharding))):
raise ValueError(
'Sharding has to be concrete when layout is of type'
f' {type(device_local_layout)}. Please pass a'
Expand Down
Loading

0 comments on commit ce13a14

Please sign in to comment.