Skip to content

Commit

Permalink
[Take 2] Expose .layout on jax.Array. Also add checks in the AOT path…
Browse files Browse the repository at this point in the history
… to make sure that the input Array's layout matches the layout given to jax.jit.

Reverts cd79e71

PiperOrigin-RevId: 618878870
  • Loading branch information
yashk2810 authored and jax authors committed Mar 25, 2024
1 parent b9e699f commit 25d01e9
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 76 deletions.
16 changes: 15 additions & 1 deletion jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import layout
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
Expand Down Expand Up @@ -527,6 +529,18 @@ def addressable_shards(self) -> Sequence[Shard]:
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out

@property
def layout(self):
  """Returns this array's layout wrapped in a `SpecifiedLayout`, or None.

  The layout is built from `self._pjrt_layout`. If doing so raises an
  `XlaRuntimeError` whose first argument is a string starting with
  "UNIMPLEMENTED", None is returned instead; any other `XlaRuntimeError`
  is re-raised unchanged.
  """
  # TODO(yashkatariya): Remove the try;except when pathways supports layouts.
  try:
    return layout.SpecifiedLayout(self._pjrt_layout)
  except xe.XlaRuntimeError as e:
    # Do not unpack `e.args` directly: an exception raised without arguments
    # would turn into an unrelated ValueError and hide the real error.
    msg = e.args[0] if e.args else None
    if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"):
      return None
    raise

@property
def global_shards(self) -> Sequence[Shard]:
"""Returns list of all `Shard`s of the Array across all devices.
Expand Down Expand Up @@ -637,7 +651,7 @@ def _value(self) -> np.ndarray:
ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)


# explicitly set to be unhashable. Same as what device_array.py does.
# explicitly set to be unhashable.
setattr(ArrayImpl, "__hash__", None)
setattr(ArrayImpl, "__array_priority__", 100)

Expand Down
10 changes: 5 additions & 5 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from jax._src import xla_bridge as xb
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, LayoutRequest
from jax._src.layout import AutoLayout, SpecifiedLayout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
Expand Down Expand Up @@ -834,10 +834,10 @@ def _to_physical_op_sharding(
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore


def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
if layout is None:
return "default"
if isinstance(layout, LayoutRequest):
if isinstance(layout, AutoLayout):
return "auto"
return layout._to_xla_layout()

Expand All @@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
num_replicas: int = 1,
Expand Down
60 changes: 42 additions & 18 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
from jax._src.layout import SpecifiedLayout, AutoLayout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -1985,13 +1985,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
return False
return True

MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]
MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]


class AllArgsInfo(NamedTuple):
"""Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
# Abstract values of all arguments.
in_avals: Sequence[core.ShapedArray]
# Input shardings for all arguments (loosely typed).
in_shardings: Any
# Input layouts for all arguments (loosely typed).
in_layouts: Any
# Debug info of the jaxpr, or None when unavailable.
debug_info: core.JaxprDebugInfo | None


Expand Down Expand Up @@ -2023,7 +2024,7 @@ def lower_sharding_computation(
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore

all_args_info = AllArgsInfo(global_in_avals, in_shardings,
all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
closed_jaxpr.jaxpr.debug_info)

(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
Expand Down Expand Up @@ -2227,8 +2228,6 @@ def lower_mesh_computation(
out_jaxpr_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts

all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)

assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_avals = out_jaxpr_avals
Expand Down Expand Up @@ -2319,7 +2318,7 @@ def lower_mesh_computation(
in_layouts=(None,) * len(global_in_avals),
out_layouts=(None,) * len(global_out_avals),
shape_poly_state=lowering_result.shape_poly_state,
all_args_info=all_args_info)
all_args_info=None)

class MeshComputation(stages.XlaLowering):
_hlo: ir.Module | None
Expand Down Expand Up @@ -2599,7 +2598,7 @@ def _get_layouts_from_executable(
if isinstance(i, SpecifiedLayout):
if i != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
new_in_layouts.append(i)
else:
new_in_layouts.append(x)
Expand All @@ -2610,7 +2609,7 @@ def _get_layouts_from_executable(
if isinstance(o, SpecifiedLayout):
if o != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
new_out_layouts.append(o)
else:
new_out_layouts.append(x)
Expand Down Expand Up @@ -3016,19 +3015,24 @@ def call(self, *args):
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
ref_avals = self.in_avals
in_shardings = self._in_shardings
in_layouts = self._in_layouts
debug_info = None
else:
kept_args = args
ref_avals = self._all_args_info.in_avals
iter_in_shardings = iter(self._in_shardings)
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_shardings)]
iter_in_layouts = iter(self._in_layouts)
in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_layouts)]
debug_info = self._all_args_info.debug_info

arg_avals = map(xla.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
check_array_xla_sharding_layout_match(kept_args, in_shardings,
in_layouts, debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable

def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
Expand Down Expand Up @@ -3184,15 +3188,17 @@ def check_device_backend_on_shardings(shardings) -> bool:
return False


def check_gda_or_array_xla_sharding_match(
def check_array_xla_sharding_layout_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
in_xla_layouts: Sequence[SpecifiedLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
from jax._src.array import ArrayImpl
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
jaxpr_debug_info.arg_names)
errors = []
num_errors = 5
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
arg_names):
if not isinstance(arg, ArrayImpl):
continue
if is_unspecified_or_auto(xs):
Expand All @@ -3205,27 +3211,45 @@ def check_gda_or_array_xla_sharding_match(
# Raise memory kind mismatch error even if the arg is uncommitted.
if arg.sharding.memory_kind != xs.memory_kind:
errors.append(
"Got input sharding(s) that compiled object was called with: "
("Got input sharding(s) that compiled object was called with: "
f"{arg.sharding} and sharding(s) the computation was compiled "
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
'sharding'))

if (not db_xs and arg._committed and
not op_shardings.are_op_shardings_equal(
arg.sharding._to_xla_hlo_sharding(arg.ndim),
xs._to_xla_hlo_sharding(arg.ndim))):
errors.append(
"Got input sharding(s) that compiled object was called with: "
("Got input sharding(s) that compiled object was called with: "
f"{arg.sharding} and sharding(s) the computation was compiled "
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
'sharding'))

if (xla_extension_version >= 249 and not db_xs and arg._committed and
arg.layout is not None and xl is not None and arg.layout != xl):
errors.append(
("Got input layout(s) that compiled object was called with: "
f"{arg.layout} and layout(s) the computation was compiled "
f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
'layout'))

if errors:
str_errors = '\n'.join(errors[:num_errors])
first_errors, error_kinds = unzip2(errors[:num_errors])
str_errors = '\n'.join(first_errors)
if all(k == 'sharding' for k in error_kinds):
kind_str = r'sharding(s)'
elif all(k == 'layout' for k in error_kinds):
kind_str = 'layout(s)'
else:
kind_str = 'sharding(s) and layout(s)'
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise ValueError(
"Compiled object called with input sharding(s) does not match the "
"sharding(s) the computation was compiled with. "
f"Compiled object called with input {kind_str} does "
f"not match the {kind_str} the computation was "
"compiled with. "
f"Here are {num_mismatch_str}:\n{str_errors}")


Expand Down
29 changes: 6 additions & 23 deletions jax/_src/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from __future__ import annotations

import re

from jax._src.lib import xla_client as xc


Expand All @@ -24,16 +22,10 @@ class Layout:
pass


class XLACompatibleLayout(Layout):

def _to_xla_layout(self) -> str:
raise NotImplementedError("Subclasses should implement this method.")


class SpecifiedLayout(XLACompatibleLayout):
layout: xc.Layout
class SpecifiedLayout(Layout):
layout: xc.PjRtLayout

def __init__(self, layout: xc.Layout):
def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)

Expand All @@ -51,19 +43,10 @@ def __eq__(self, other):
def _to_xla_layout(self) -> str:
return self._layout_str

@property
def _minor_to_major(self):
m = re.search("{([0-9,]*):", str(self))
assert m is not None
m2m_str = m.group(1)
if m2m_str == "":
return ()
return tuple(int(x) for x in m2m_str.split(","))


class LayoutRequest:
class AutoLayout:

def __repr__(self):
return "Request a layout from the compiler"
return "AUTO"

AUTO = LayoutRequest()
AUTO = AutoLayout()
37 changes: 34 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def lower(*args, **kwargs):
try:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
in_layouts_flat = _resolve_in_layouts(
args_flat, in_layouts_flat, in_shardings)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
Expand Down Expand Up @@ -1130,7 +1132,6 @@ def unpack(key):
p("explanation unavailable! please open an issue at https://github.com/google/jax")
return done()


@partial(lu.cache, explain=explain_tracing_cache_miss)
def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
del ignored_inline # just for explain_cache_miss
Expand Down Expand Up @@ -1264,6 +1265,35 @@ def pjit_check_aval_sharding(
pjit_p.multiple_results = True


def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
  """Resolves the input layout to use for each argument.

  Args:
    args: The flat arguments the function is being lowered with.
    jit_in_layouts: The flat input layouts given to jit, one per argument.
      A `None` entry means no layout was requested for that argument.
    jit_in_shardings: The resolved input shardings, used only to detect
      whether device or backend was set on jit.

  Returns:
    A tuple with one resolved layout per argument. When jit requested no
    layout for an argument, this is the argument's own layout if the
    argument has one and is committed, otherwise `None`. When jit requested
    a layout, that layout is returned.

  Raises:
    ValueError: If jit requested a layout for an argument and the argument
      is committed with a layout that differs from it.
  """
  # If device or backend is set, return the default layout. This is because you
  # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
  # which causes error checks to fail. Returning the default layout allows
  # this to exist. It's the same for handling shardings.
  if pxla.check_device_backend_on_shardings(jit_in_shardings):
    return (None,) * len(jit_in_layouts)

  resolved_in_layouts = []
  for arg, jit_in_l in safe_zip(args, jit_in_layouts):
    # Read the layout once; args without a (non-None) layout are treated as
    # uncommitted so that they never pin or conflict with a layout.
    arg_layout = getattr(arg, 'layout', None)
    committed = (getattr(arg, '_committed', True)
                 if arg_layout is not None else False)
    if jit_in_l is None:
      resolved_in_layouts.append(arg_layout if committed else None)
    else:
      if committed and arg_layout != jit_in_l:
        raise ValueError('Layout passed to jit does not match the layout '
                         'on the respective arg. '
                         f'Got pjit layout: {jit_in_l},\n'
                         f'arg layout: {arg_layout} for '
                         f'arg shape: {shaped_abstractify(arg).str_short()}')
      resolved_in_layouts.append(jit_in_l)
  return tuple(resolved_in_layouts)


def _resolve_in_shardings(
args, pjit_in_shardings: Sequence[PjitSharding],
out_shardings: Sequence[PjitSharding],
Expand Down Expand Up @@ -1387,8 +1417,9 @@ def _pjit_call_impl_python(
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.enable_checks.value:
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
jaxpr.jaxpr.debug_info)
pxla.check_array_xla_sharding_layout_match(
args, compiled._in_shardings, compiled._in_layouts,
jaxpr.jaxpr.debug_info)
if config.distributed_debug.value:
# Defensively only perform fingerprint logic if debug logging is enabled
# NOTE(skyewm): I didn't benchmark this
Expand Down
Loading

0 comments on commit 25d01e9

Please sign in to comment.