sharding cleanup: use inline checks for unspecified and auto #24470

Merged 1 commit on Oct 25, 2024
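This PR replaces the `is_unspecified`, `is_auto`, and `is_unspecified_or_auto` helper predicates from `jax._src.sharding_impls` with inline `isinstance` checks against `UnspecifiedValue` and `AUTO`. Because `isinstance` lets type checkers narrow the type, several `# type: ignore` comments also become unnecessary. Below is a minimal sketch of the pattern; the `resolve` helper is hypothetical, but the imported names match the ones used in the diff.

```python
from jax._src.sharding_impls import AUTO, UnspecifiedValue

def resolve(s):
    # Before this PR: `if is_unspecified(s):` — opaque to type checkers,
    # so later attribute accesses needed `# type: ignore`.
    # After: inline isinstance checks, which narrow the type of `s`.
    if isinstance(s, UnspecifiedValue):
        return None    # no sharding was specified
    if isinstance(s, AUTO):
        return s.mesh  # AUTO carries the mesh it was constructed with
    return s           # a concrete Sharding
```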
4 changes: 2 additions & 2 deletions jax/_src/export/_export.py
@@ -801,9 +801,9 @@ def _export_lowered(
nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> HloSharding | None:
-if sharding_impls.is_unspecified(s):
+if isinstance(s, sharding_impls.UnspecifiedValue):
return None
-return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]
+return s._to_xla_hlo_sharding(aval.ndim)

all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
module_kept_var_idx,
55 changes: 28 additions & 27 deletions jax/_src/interpreters/pxla.py
@@ -68,8 +68,8 @@
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
-UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
-is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
+UnspecifiedValue, get_array_mapping as _get_array_mapping,
+array_mapping_to_axis_resources,
SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
tuple_update, tuple_delete, distributed_debug_log,
@@ -149,7 +149,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args,

@lru_cache(maxsize=2048)
def is_default_layout(curr_layout, sharding, aval):
-if curr_layout is None or sharding is None or is_unspecified(sharding):
+if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue):
return True
if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
dtypes.issubdtype(aval.dtype, dtypes.extended)):
@@ -1643,7 +1643,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
def check_if_any_auto(
shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
for s in shardings:
-if is_auto(s):
+if isinstance(s, AUTO):
return True
return False

@@ -1727,14 +1727,14 @@ def _get_and_check_device_assignment(
devices = tuple(devices)

for i, s_type, source_info in shardings:
-if is_unspecified(i):
+if isinstance(i, UnspecifiedValue):
continue

if first_sharding_info is None:
first_sharding_info = (
-(i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore
-else (i._device_assignment, s_type, source_info)) # type: ignore
-arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore
+(i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
+else (i._device_assignment, s_type, source_info))
+arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
@@ -1836,7 +1836,8 @@ class SemanticallyEqualShardings:
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
avals: tuple[core.AbstractValue]):
gspmd_shardings = [
-s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore
+s if isinstance(s, (UnspecifiedValue, AUTO))
+else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error
for s, a in zip(shardings, avals)]
self._gspmd_shardings = gspmd_shardings
self.shardings = shardings
@@ -2004,7 +2005,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
except:
return True
for i in shardings:
-if is_unspecified_or_auto(i):
+if isinstance(i, (UnspecifiedValue, AUTO)):
continue
if i.memory_kind is None: # pytype: disable=attribute-error
continue
@@ -2034,7 +2035,7 @@ def _default_rule(prim, num_outvars, *_, **__):
if in_shardings is None:
invar_mem_kind = [None] * len(jaxpr.invars)
else:
-invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind
+invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind
for s in in_shardings]
safe_map(write, jaxpr.invars, invar_mem_kind)
safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))
@@ -2129,7 +2130,7 @@ def _abstract_to_concrete_mesh(abstract_mesh):

out = []
for s, a in zip(shardings, avals):
-if is_unspecified(s) and a.sharding is not None:
+if isinstance(s, UnspecifiedValue) and a.sharding is not None:
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
@@ -2216,9 +2217,9 @@ def lower_sharding_computation(
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
-any(not is_unspecified(i) for i in unique_in_shardings) or
-any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
-any(not is_unspecified(o) for o in unique_out_shardings))
+any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
+any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
+any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))

da_object = _create_da_object(tuple(device_assignment))

@@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings(
new_in_shardings = []
for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
global_in_avals):
-if is_unspecified(orig):
+if isinstance(orig, UnspecifiedValue):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
@@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings(
new_out_shardings = []
for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
global_out_avals):
-if is_unspecified(orig):
+if isinstance(orig, UnspecifiedValue):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
@@ -2839,16 +2840,16 @@ def from_hlo(name: str,
da = _create_da_object(tuple(device_assignment))
del device_assignment

-allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i)
+allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
for i in in_shardings)
-allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o)
+allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
for o in out_shardings)

mesh = None
if auto_spmd_lowering:
for i in it.chain.from_iterable([in_shardings, out_shardings]):
-if is_auto(i):
-mesh = i.mesh # type: ignore
+if isinstance(i, AUTO):
+mesh = i.mesh
break

xla_executable = _cached_compilation(
@@ -2861,9 +2862,9 @@
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
-in_shardings = [x if is_auto(i) else i
+in_shardings = [x if isinstance(i, AUTO) else i
for x, i in safe_zip(in_shardings_xla, in_shardings)]
-out_shardings = [x if is_auto(o) else o
+out_shardings = [x if isinstance(o, AUTO) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]
else:
if pmap_nreps == 1:
@@ -2954,8 +2955,8 @@ def contains_explicit_attributes(self):
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
-any(not is_unspecified(i) for i in self.in_shardings_leaves) or
-any(not is_unspecified(o) for o in self.out_shardings_leaves) or
+any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or
+any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))

@@ -3130,7 +3131,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):

def check_device_backend_on_shardings(shardings) -> bool:
for i in shardings:
-if is_unspecified(i) or is_auto(i):
+if isinstance(i, (UnspecifiedValue, AUTO)):
continue
if getattr(i, '_device_backend', False):
return True
@@ -3156,7 +3157,7 @@ def check_array_xla_sharding_layout_match(
args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
if not isinstance(arg, ArrayImpl):
continue
-if is_unspecified_or_auto(xs):
+if isinstance(xs, (UnspecifiedValue, AUTO)):
continue

db_xs = check_device_backend_on_shardings([xs])
4 changes: 2 additions & 2 deletions jax/_src/layout.py
@@ -19,7 +19,7 @@
import numpy as np
from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding
-from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
+from jax._src.sharding_impls import AUTO as AutoSharding
from jax._src.lib import xla_client as xc

Shape = tuple[int, ...]
@@ -101,7 +101,7 @@ def __init__(self, device_local_layout: LayoutOptions = None,
sharding: ShardingOptions = None):
# If layout is concrete and sharding is not, error.
if (isinstance(device_local_layout, DeviceLocalLayout) and
-(sharding is None or is_auto(sharding))):
+(sharding is None or isinstance(sharding, AutoSharding))):
raise ValueError(
'Sharding has to be concrete when layout is of type'
f' {type(device_local_layout)}. Please pass a'
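The layout.py change keeps the existing invariant: a concrete `DeviceLocalLayout` requires a concrete sharding, so both `None` and `AUTO` are rejected. A minimal sketch of that check as a hypothetical standalone function (names follow the diff, including the `AutoSharding` alias):

```python
from jax._src.layout import DeviceLocalLayout
from jax._src.sharding_impls import AUTO as AutoSharding

def validate_layout_args(device_local_layout, sharding):
    # A concrete layout with no concrete sharding is an error.
    if (isinstance(device_local_layout, DeviceLocalLayout) and
        (sharding is None or isinstance(sharding, AutoSharding))):
        raise ValueError(
            "Sharding has to be concrete when layout is of type "
            f"{type(device_local_layout)}.")
```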