diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index b1bb797d538c..99794c8cc23c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -801,9 +801,9 @@ def _export_lowered( nr_devices = len(lowering.compile_args["device_assignment"]) def export_sharding(s: LoweringSharding, aval: core.ShapedArray) -> HloSharding | None: - if sharding_impls.is_unspecified(s): + if isinstance(s, sharding_impls.UnspecifiedValue): return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + return s._to_xla_hlo_sharding(aval.ndim) all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], module_kept_var_idx, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a14ca3dcabd8..f2b0a13883cc 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -68,8 +68,8 @@ from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED, - UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, - is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources, + UnspecifiedValue, get_array_mapping as _get_array_mapping, + array_mapping_to_axis_resources, SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, @@ -149,7 +149,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, @lru_cache(maxsize=2048) def is_default_layout(curr_layout, sharding, aval): - if curr_layout is None or sharding is None or is_unspecified(sharding): + if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue): return True if (aval is core.abstract_token or aval.dtype == dtypes.float0 or dtypes.issubdtype(aval.dtype, dtypes.extended)): @@ -1643,7 +1643,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp def check_if_any_auto( shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool: for s in shardings: - if is_auto(s): + if isinstance(s, AUTO): return True return False @@ -1727,14 +1727,14 @@ def _get_and_check_device_assignment( devices = tuple(devices) for i, s_type, source_info in shardings: - if is_unspecified(i): + if isinstance(i, UnspecifiedValue): continue if first_sharding_info is None: first_sharding_info = ( - (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore - else (i._device_assignment, s_type, source_info)) # type: ignore - arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore + (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO) + else (i._device_assignment, s_type, source_info)) + arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment if not devices: if first_sharding_info[0] != arr_device_assignment: raise DeviceAssignmentMismatchError([ @@ -1836,7 +1836,8 @@ class SemanticallyEqualShardings: def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], avals: tuple[core.AbstractValue]): gspmd_shardings = [ - s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore + s if isinstance(s, (UnspecifiedValue, AUTO)) + else to_gspmd_sharding(s, a.ndim) for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings @@ -2004,7 +2005,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i): + if isinstance(i, (UnspecifiedValue, AUTO)): continue if i.memory_kind is None: # pytype: disable=attribute-error continue @@ -2034,7 +2035,7 @@ def _default_rule(prim, num_outvars, *_, **__): if in_shardings is None: invar_mem_kind = [None] * len(jaxpr.invars) else: - invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind + invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind for s in in_shardings] safe_map(write, jaxpr.invars, invar_mem_kind) safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) @@ -2129,7 +2130,7 @@ def _abstract_to_concrete_mesh(abstract_mesh): out = [] for s, a in zip(shardings, avals): - if is_unspecified(s) and a.sharding is not None: + if isinstance(s, UnspecifiedValue) and a.sharding is not None: out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh), a.sharding.spec)) else: @@ -2216,9 +2217,9 @@ def lower_sharding_computation( committed = bool( devices_from_context or len(device_assignment) > 1 or - any(not is_unspecified(i) for i in unique_in_shardings) or - any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or - any(not is_unspecified(o) for o in unique_out_shardings)) + any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or + any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or + any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings)) da_object = _create_da_object(tuple(device_assignment)) @@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings( new_in_shardings = [] for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings, global_in_avals): - if is_unspecified(orig): + if isinstance(orig, UnspecifiedValue): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) @@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings( new_out_shardings = [] for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings, global_out_avals): - if is_unspecified(orig): + if isinstance(orig, UnspecifiedValue): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) @@ -2839,16 +2840,16 @@ def from_hlo(name: str, da = _create_da_object(tuple(device_assignment)) del device_assignment - allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i) + allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings) - allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o) + allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO)) for o in out_shardings) mesh = None if auto_spmd_lowering: for i in it.chain.from_iterable([in_shardings, out_shardings]): - if is_auto(i): - mesh = i.mesh # type: ignore + if isinstance(i, AUTO): + mesh = i.mesh break xla_executable = _cached_compilation( @@ -2861,9 +2862,9 @@ def from_hlo(name: str, assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( xla_executable, mesh) - in_shardings = [x if is_auto(i) else i + in_shardings = [x if isinstance(i, AUTO) else i for x, i in safe_zip(in_shardings_xla, in_shardings)] - out_shardings = [x if is_auto(o) else o + out_shardings = [x if isinstance(o, AUTO) else o for x, o in safe_zip(out_shardings_xla, out_shardings)] else: if pmap_nreps == 1: @@ -2954,8 +2955,8 @@ def contains_explicit_attributes(self): self.donate_argnames is not None or self.device is not None or self.backend is not None or - any(not is_unspecified(i) for i in self.in_shardings_leaves) or - any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or + any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or any(i is not None for i in self.in_layouts_leaves) or any(o is not None for o in self.out_layouts_leaves)) @@ -3130,7 +3131,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): def check_device_backend_on_shardings(shardings) -> bool: for i in shardings: - if is_unspecified(i) or is_auto(i): + if isinstance(i, (UnspecifiedValue, AUTO)): continue if getattr(i, '_device_backend', False): return True @@ -3156,7 +3157,7 @@ def check_array_xla_sharding_layout_match( args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): if not isinstance(arg, ArrayImpl): continue - if is_unspecified_or_auto(xs): + if isinstance(xs, (UnspecifiedValue, AUTO)): continue db_xs = check_device_backend_on_shardings([xs]) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 64bbd3268b16..5309f0b1fd9c 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -19,7 +19,7 @@ import numpy as np from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding -from jax._src.sharding_impls import AUTO as AutoSharding, is_auto +from jax._src.sharding_impls import AUTO as AutoSharding from jax._src.lib import xla_client as xc Shape = tuple[int, ...] @@ -101,7 +101,7 @@ def __init__(self, device_local_layout: LayoutOptions = None, sharding: ShardingOptions = None): # If layout is concrete and sharding is not, error. if (isinstance(device_local_layout, DeviceLocalLayout) and - (sharding is None or is_auto(sharding))): + (sharding is None or isinstance(sharding, AutoSharding))): raise ValueError( 'Sharding has to be concrete when layout is of type' f' {type(device_local_layout)}. Please pass a' diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a69e8987b2d8..2abf81f26aa4 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -67,8 +67,7 @@ from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, - ParsedPartitionSpec, get_single_pspec, is_unspecified, - is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) + ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding) from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary @@ -418,10 +417,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") - if in_shardings is not None and not is_unspecified(in_shardings): + if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue): raise ValueError('If backend or device is specified on jit, then ' 'in_shardings should not be specified.') - if out_shardings is not None and not is_unspecified(out_shardings): + if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue): raise ValueError('If backend or device is specified on jit, then ' 'out_shardings should not be specified.') @@ -440,7 +439,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') user_specified_in_shardings = (in_shardings is not None and - not is_unspecified(in_shardings)) + not isinstance(in_shardings, UnspecifiedValue)) in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings) out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings) @@ -483,7 +482,7 @@ def lower(*args, **kwargs): @api_boundary def eval_shape(*args, **kwargs): p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] + out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, weak_type=x.weak_type) @@ -1001,7 +1000,7 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): if x is None and (mesh is None or mesh.empty): return UNSPECIFIED - if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x): + if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)): return x if mesh is None: msg = ('jax.jit only supports `Sharding`s being passed to' @@ -1110,7 +1109,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves) # Only do this if original in_shardings are unspecified. If it is AUTO, go # via flatten_axis_resources. - if is_unspecified(orig_in_shardings): + if isinstance(orig_in_shardings, UnspecifiedValue): in_shardings_flat = (orig_in_shardings,) * len(in_avals) else: in_shardings_flat = flatten_axis_resources( @@ -1312,8 +1311,7 @@ def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) - if (is_unspecified(orig_out_shardings) or - isinstance(orig_out_shardings, sharding.Sharding)): + if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)): out_shardings_flat = (orig_out_shardings,) * len(out_avals) else: out_shardings_flat = flatten_axis_resources( @@ -1391,7 +1389,7 @@ def pjit_check_aval_sharding( what_aval: str, allow_uneven_sharding: bool): new_names = [''] * len(shardings) if names is None else names for aval, s, name in zip(flat_avals, shardings, new_names): - if is_unspecified_or_auto(s): + if isinstance(s, (UnspecifiedValue, AUTO)): continue name_str = f' with pytree key path {name}' if name else '' shape = aval.shape @@ -1466,7 +1464,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): else: arg_layout, dispatch_arg_layout = None, None # Sharding can be unspecified when array is committed if it's a PmapSharding. - is_pmap_sharding = (is_unspecified(rs) or + is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or isinstance(getattr(arg, 'sharding', None), PmapSharding)) if jit_in_l is None: if committed: @@ -1527,15 +1525,15 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if getattr(a, '_committed', True): committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) - resolved_in_shardings = [] + resolved_in_shardings: list[PjitSharding] = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) if hasattr(arg, 'sharding') and arg.sharding is not None else (UNSPECIFIED, False)) - if is_unspecified(pjit_in_s): - if is_unspecified(arg_s): + if isinstance(pjit_in_s, UnspecifiedValue): + if isinstance(arg_s, UnspecifiedValue): resolved_in_shardings.append(arg_s) else: if committed: @@ -1553,7 +1551,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'multiple devices is not supported.') else: if (isinstance(arg, np.ndarray) and - not pjit_in_s.is_fully_replicated and # type: ignore + not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] xb.process_count() > 1): raise ValueError( 'Passing non-trivial shardings for numpy ' @@ -1572,16 +1570,16 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, # either outside or inside jax.jit. - if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore + if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore[union-attr] raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' - f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore + f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore[union-attr] f'arg memory kind: {arg_s.memory_kind} for ' f'arg shape: {shaped_abstractify(arg).str_short()}') if (committed and not isinstance(arg_s, PmapSharding) and not op_shardings.are_op_shardings_equal( - pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore + pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore[union-attr] arg_s._to_xla_hlo_sharding(arg.ndim))): raise ValueError('Sharding passed to pjit does not match the sharding ' 'on the respective arg. ' @@ -1780,8 +1778,8 @@ def pjit_staging_rule(trace, *args, **params): params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) if (params["inline"] and - all(is_unspecified(i) for i in params["in_shardings"]) and - all(is_unspecified(o) for o in params["out_shardings"]) and + all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and + all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): if config.dynamic_shapes.value: @@ -1830,7 +1828,7 @@ def pjit_staging_rule(trace, *args, **params): def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) @@ -1896,8 +1894,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, func = mod_ctx.cached_primitive_lowerings.get(key, None) if func is None: - arg_shardings = [None if is_unspecified(i) else i for i in in_shardings] - result_shardings = [None if is_unspecified(o) else o for o in out_shardings] + arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings] + result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings] # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. # using_sharding_annotation=False means we add an identity operation instead. @@ -1990,9 +1988,9 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, def _pjit_batcher_for_sharding( s: sharding.Sharding | UnspecifiedValue, dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): - if is_unspecified(s): + if isinstance(s, UnspecifiedValue): return s - hlo_s = s._to_xla_hlo_sharding(ndim) # type: ignore + hlo_s = s._to_xla_hlo_sharding(ndim) if spmd_axis_name is None: if sharding_impls.is_op_sharding_replicated(hlo_s): return s @@ -2004,7 +2002,7 @@ def _pjit_batcher_for_sharding( tad.insert(dim, 1) new_op.tile_assignment_dimensions = tad new_gs = GSPMDSharding( - s._device_assignment, new_op, # type: ignore + s._device_assignment, new_op, _device_list=getattr(s, '_internal_device_list', None)) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: @@ -2107,7 +2105,7 @@ def keep_where(l, should_keep): # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) in_fwd = [ - fwd if is_unspecified(os) and ol is None else None + fwd if isinstance(os, UnspecifiedValue) and ol is None else None for os, ol, fwd in zip( keep_where(out_shardings, known_outs), keep_where(out_layouts, known_outs), in_fwd_primal) @@ -2358,9 +2356,9 @@ def _pjit_pp_rule(eqn, context, settings): del params['inline'] if not any(params['donated_invars']): del params['donated_invars'] - if all(is_unspecified(s) for s in params['in_shardings']): + if all(isinstance(s, UnspecifiedValue) for s in params['in_shardings']): del params['in_shardings'] - if all(is_unspecified(s) for s in params['out_shardings']): + if all(isinstance(s, UnspecifiedValue) for s in params['out_shardings']): del params['out_shardings'] if all(l is None for l in params['in_layouts']): del params['in_layouts'] @@ -2382,8 +2380,7 @@ def _pjit_pp_rule(eqn, context, settings): def _pjit_state_discharge_rule( in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, **params): - if not (all(map(is_unspecified, in_shardings)) and - all(map(is_unspecified, out_shardings))): + if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)): raise NotImplementedError if not (all(l is None for l in in_layouts) and diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index f54c39efebce..0b9c8b532db6 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -965,21 +965,11 @@ def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) -def is_auto(x): - return isinstance(x, AUTO) - - class UnspecifiedValue: def __repr__(self): return "UnspecifiedValue" UNSPECIFIED = UnspecifiedValue() -def is_unspecified(x): - return isinstance(x, UnspecifiedValue) - -def is_unspecified_or_auto(x): - return is_auto(x) or is_unspecified(x) - MeshAxisName = Any @@ -1022,8 +1012,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): def get_array_mapping( axis_resources: ParsedPartitionSpec | AUTO | UnspecifiedValue ) -> ArrayMappingOrAutoOrUnspecified: - # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported. - # Don't use `is_auto` here to satisfy pytype and mypy. if isinstance(axis_resources, (AUTO, UnspecifiedValue)): return axis_resources return OrderedDict((axis, i) @@ -1121,7 +1109,7 @@ def prepare_axis_resources(axis_resources, arg_name, new_entries = [] for entry in entries: - if is_unspecified_or_auto(entry) or entry is None: + if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None: new_entries.append(entry) elif isinstance(entry, sharding.Sharding): if isinstance(entry, PmapSharding): @@ -1139,8 +1127,7 @@ def prepare_axis_resources(axis_resources, arg_name, def _check_unique_resources(axis_resources, arg_name): for arg_axis_resources in axis_resources: if not arg_axis_resources: continue - if (is_unspecified_or_auto(arg_axis_resources) or - isinstance(arg_axis_resources, sharding.Sharding)): + if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)): continue constrained_dims = [d for d in arg_axis_resources if d is not None] resource_counts = collections.Counter( diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3a2c375b64db..92c680009c93 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -43,7 +43,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.sharding_impls import is_unspecified_or_auto +from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -649,7 +649,7 @@ def out_info(self): # PyTree of OutInfo out_avals = self._lowering.compile_args["global_out_avals"] out_shardings = self._lowering.compile_args["out_shardings"] return self.out_tree.unflatten( - [OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s) + [OutInfo(o.shape, o.dtype, None if isinstance(s, (UnspecifiedValue, AUTO)) else s) for o, s in zip(out_avals, out_shardings)]) def compile( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 29a1034e51ed..0f7b88005e65 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3523,7 +3523,7 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( s: sharding.Sharding, aval: core.ShapedArray) -> xla_client.HloSharding | None: - if sharding_impls.is_unspecified(s): + if isinstance(s, sharding_impls.UnspecifiedValue): return None return s._to_xla_hlo_sharding(aval.ndim) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 15c9a2cfe49d..f3fd8bac558c 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -40,7 +40,6 @@ ArrayMapping as ArrayMapping, UNSPECIFIED as _UNSPECIFIED, # noqa: F401 array_mapping_to_axis_resources as array_mapping_to_axis_resources, - is_unspecified as _is_unspecified, # noqa: F401 ) from jax._src.sharding_specs import (