Merge pull request #18761 from gnecula:export_sharding
PiperOrigin-RevId: 587330573
jax authors committed Dec 2, 2023
2 parents b822801 + 3eb3e2d commit b51b80e
Showing 6 changed files with 126 additions and 98 deletions.
19 changes: 12 additions & 7 deletions jax/_src/internal_test_util/export_back_compat_test_util.py
@@ -87,7 +87,6 @@ def func(...): ...

from jax._src import core
from jax._src import test_util as jtu
from jax._src.sharding_impls import UNSPECIFIED
from jax._src import xla_bridge as xb


@@ -104,6 +103,7 @@ class CompatTestData:
mlir_module_text: str
mlir_module_serialized: bytes
xla_call_module_version: int # The version of XlaCallModule to use for testing
nr_devices: int = 1


# The dummy_data is used for getting started for adding a new test and for
@@ -187,6 +187,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
will fail when the serialization changes. Otherwise, when checking old
serializations you can specify what custom calls are expected in the
current serialization.
nr_devices: the number of devices for which the data was serialized.
"""
if not isinstance(data, CompatTestData):
raise ValueError(f"Expecting data: CompatTestData but got {data}. "
@@ -202,7 +203,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
res_run_current = tuple(np.array(a) for a in res_run_current)
logging.info("Result of current version run is %s", res_run_current)

serialized, module_str, module_version = self.serialize(
serialized, module_str, module_version, nr_devices = self.serialize(
func, data,
polymorphic_shapes=polymorphic_shapes,
allow_unstable_custom_call_targets=allow_unstable_custom_call_targets)
@@ -225,6 +226,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
mlir_module_text=r\"\"\"\n{module_str}\"\"\",
mlir_module_serialized={serialized!r},
xla_call_module_version={module_version},
nr_devices={nr_devices},
) # End paste
"""
@@ -271,7 +273,7 @@ def serialize(self,
func: Callable, data: CompatTestData, *,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_unstable_custom_call_targets: Sequence[str] = ()
) -> tuple[bytes, str, int]:
) -> tuple[bytes, str, int, int]:
"""Serializes the test function.
Args:
@@ -281,7 +283,8 @@ def serialize(self,
custom call targets besides those known as stable.
Returns: a tuple with the (a) serialization, (b) the module contents as
a string (for debugging), and (c) the module serialization version.
a string (for debugging), (c) the module serialization version,
(d) the number of devices for which the module was serialized.
"""
# Use the native exporter, to make sure we get the proper serialization.
args_specs = export.args_specs(data.inputs, polymorphic_shapes)
@@ -296,7 +299,8 @@ def serialize(self,
module_str = str(exported.mlir_module())
serialized = exported.mlir_module_serialized
module_version = exported.serialization_version
return serialized, module_str, module_version
nr_devices = exported.nr_devices
return serialized, module_str, module_version, nr_devices

def run_serialized(self, data: CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None):
@@ -321,14 +325,15 @@ def _get_vjp(_):
in_avals=tuple(in_avals),
out_tree=out_tree,
out_avals=tuple(out_avals),
in_shardings=(UNSPECIFIED,) * len(in_avals),
out_shardings=(UNSPECIFIED,) * len(out_avals),
in_shardings=(None,) * len(in_avals),
out_shardings=(None,) * len(out_avals),
lowering_platforms=(data.platform,),
ordered_effects=(),
unordered_effects=(),
disabled_checks=(),
mlir_module_serialized=data.mlir_module_serialized,
serialization_version=data.xla_call_module_version,
nr_devices=data.nr_devices,
module_kept_var_idx=tuple(range(len(in_avals))),
uses_shape_polymorphism=any(not core.is_constant_shape(a.shape)
for a in in_avals),
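For readers skimming the test-utility change above: the new `nr_devices` field defaults to 1, so back-compat test data generated before this commit keeps loading unchanged. A minimal stand-alone sketch (illustrative names, not the JAX class itself):

```python
import dataclasses

# Stand-in for CompatTestData: records saved before this commit simply omit
# `nr_devices` and pick up the single-device default.
@dataclasses.dataclass
class CompatTestDataSketch:
    platform: str
    mlir_module_serialized: bytes
    xla_call_module_version: int
    nr_devices: int = 1

legacy = CompatTestDataSketch(
    platform="cpu",
    mlir_module_serialized=b"<serialized MLIR bytes>",
    xla_call_module_version=9,
)
assert legacy.nr_devices == 1  # single device assumed for legacy records
```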
123 changes: 66 additions & 57 deletions jax/experimental/export/export.py
@@ -122,7 +122,11 @@ def __hash__(self) -> int:
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9

Sharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]

# None means unspecified sharding
Sharding = Optional[xla_client.HloSharding]
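Under the new alias, `None` means "unspecified" and concrete values are `xla_client.HloSharding` objects. A hedged sketch of such a value (the `replicate()` constructor and the proto round-trip are assumed from the `xla_client` bindings; treat the exact API surface as an assumption):

```python
from jax._src.lib import xla_client

# A concrete value for the new `Sharding` alias: a replicated HloSharding.
replicated = xla_client.HloSharding.replicate()
print(replicated.is_replicated())  # True
print(replicated.to_proto())       # OpSharding proto, as consumed by wrap_with_sharding_op
```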

@dataclasses.dataclass(frozen=True)
class Exported:
@@ -140,9 +144,9 @@ class Exported:
out_avals: the flat tuple of output abstract values. May contain dimension
expressions in the shapes, with dimension variables among those in
`in_avals.
in_shardings: the flattened input shardings. Only for the inputs that are
specified in `module_kept_var_idx`.
in_shardings: the flattened input shardings, as long as `in_avals`.
out_shardings: the flattened output shardings, as long as `out_avals`.
nr_devices: the number of devices that the module has been lowered for.
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
'cuda', 'rocm'. See below for the calling convention for when
there are multiple lowering platforms.
Expand Down Expand Up @@ -274,6 +278,7 @@ class Exported:

in_shardings: tuple[Sharding, ...]
out_shardings: tuple[Sharding, ...]
nr_devices: int
lowering_platforms: tuple[str, ...]
ordered_effects: tuple[effects.Effect, ...]
unordered_effects: tuple[effects.Effect, ...]
@@ -521,14 +526,31 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
ordered_effects = unordered_effects = ()

nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> Sharding:
if sharding_impls.is_unspecified(s):
return None
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]

all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
module_kept_var_idx,
len(args_avals_flat))
in_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(all_in_shardings, args_avals_flat))
out_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))
return Exported(
fun_name=fun_name,
in_tree=lowered.in_tree,
out_tree=lowered.out_tree,
in_avals=tuple(args_avals_flat),
out_avals=tuple(out_avals_flat),
in_shardings=tuple(lowering.compile_args["in_shardings"]),
out_shardings=tuple(lowering.compile_args["out_shardings"]),
in_shardings=in_shardings,
out_shardings=out_shardings,
nr_devices=nr_devices,
lowering_platforms=actual_lowering_platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
@@ -921,64 +943,57 @@ def walk_operations(op):
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
raise ValueError(msg)

def expand_in_shardings(in_shardings: tuple[Sharding, ...],
def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
module_kept_var_idx: Sequence[int],
nr_inputs: int) -> tuple[Sharding, ...]:
nr_inputs: int) -> Sequence[LoweringSharding]:
"""Expands in_shardings with unspecified shardings for inputs not kept.
Assumes in_shardings corresponds to module_kept_var_idx.
"""
assert len(in_shardings) == len(module_kept_var_idx)
assert nr_inputs >= len(module_kept_var_idx)
all_in_shardings: list[Sharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
all_in_shardings[idx] = in_s
return tuple(all_in_shardings)
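A quick worked example of the expansion rule above, using plain-Python stand-ins rather than real sharding objects (not the JAX helper itself):

```python
# Stand-in sketch: shardings recorded for the DCE-kept inputs are scattered
# back to their original positions; dropped inputs get the "unspecified"
# placeholder.
UNSPECIFIED = object()  # stand-in for sharding_impls.UNSPECIFIED

def expand(kept_shardings, module_kept_var_idx, nr_inputs):
    out = [UNSPECIFIED] * nr_inputs
    for idx, s in zip(sorted(module_kept_var_idx), kept_shardings):
        out[idx] = s
    return tuple(out)

# Inputs 0 and 2 were kept by the compiler; input 1 was dropped.
assert expand(("s0", "s2"), (0, 2), 3) == ("s0", UNSPECIFIED, "s2")
```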

# TODO(yashkatariya, necula): remove this function once we relax the checks
# in the jit front-end.
def canonical_shardings(
device_assignment: Sequence[jax.Device],
in_shardings: Sequence[Sharding],
out_shardings: Sequence[Sharding]
) -> tuple[Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]],
Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]]]:
"""Prepares canonical in_ and out_shardings for a jit invocation.
"""Prepares canonical in_ and out_shardings for a pjit invocation.
The pjit front-end is picky about what in- and out-shardings it accepts,
e.g., if all are unspecified then the whole sharding should be the
sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
replaced with the replicated sharding.
"""
# Prepare a replicated sharding, search in both the input and output shardings
specified_shardings = [
s for s in itertools.chain(in_shardings, out_shardings)
if not sharding_impls.is_unspecified(s)]
if specified_shardings:
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
else:
replicated_s = None
Returns: a pair with the canonicalized input and output shardings.
"""
replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment)
def canonicalize(
ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]]:
if all(sharding_impls.is_unspecified(s) for s in ss):
if all(s is None for s in ss):
return sharding_impls.UNSPECIFIED
return tuple(
s if not sharding_impls.is_unspecified(s) else replicated_s
sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s
for s in ss)
return (canonicalize(in_shardings), canonicalize(out_shardings))
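The canonicalization rule in `canonical_shardings` can be summarized with a stand-alone sketch (plain-Python stand-ins for the sentinel and the replicated sharding, not the JAX objects):

```python
from typing import Optional, Sequence, Union

UNSPECIFIED = object()     # stand-in for sharding_impls.UNSPECIFIED
REPLICATED = "replicated"  # stand-in for GSPMDSharding.get_replicated(...)

def canonicalize(ss: Sequence[Optional[str]]) -> Union[object, tuple]:
    # If nothing is specified, pjit wants the single UNSPECIFIED sentinel;
    # otherwise every unspecified entry is filled with the replicated sharding.
    if all(s is None for s in ss):
        return UNSPECIFIED
    return tuple(s if s is not None else REPLICATED for s in ss)

assert canonicalize([None, None]) is UNSPECIFIED
assert canonicalize(["{devices=[2]0,1}", None]) == ("{devices=[2]0,1}", REPLICATED)
```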

def _get_vjp_fun(primal_fun: Callable, *,
in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
module_kept_var_idx: tuple[int, ...],
in_shardings: tuple[Sharding, ...],
out_shardings: tuple[Sharding, ...],
nr_devices: int,
apply_jit: bool
) -> tuple[Callable, Sequence[core.AbstractValue]]:
# Since jax.vjp does not handle kwargs, it is easier to do all the work
@@ -1000,30 +1015,30 @@ def flattened_primal_fun_jax(*args_flat):
itertools.chain(in_avals,
map(lambda a: a.at_least_vspace(), out_avals)))

all_in_shardings = expand_in_shardings(in_shardings,
module_kept_var_idx, len(in_avals))
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
tuple(itertools.chain(all_in_shardings, out_shardings)),
all_in_shardings)

if apply_jit:
# Prepare a device assignment. For exporting purposes, all it matters
# is the number of devices.
device_assignment = jax.devices(jax.default_backend())[:nr_devices]
assert len(device_assignment) == nr_devices
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
device_assignment,
tuple(itertools.chain(in_shardings, out_shardings)),
in_shardings)
return pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_out_shardings), vjp_in_avals
else:
assert vjp_in_shardings == sharding_impls.UNSPECIFIED
assert vjp_out_shardings == sharding_impls.UNSPECIFIED
return fun_vjp_jax, vjp_in_avals

def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
# Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
in_tree=primal.in_tree,
module_kept_var_idx=primal.module_kept_var_idx,
in_avals=primal.in_avals,
in_shardings=primal.in_shardings,
out_avals=primal.out_avals,
out_shardings=primal.out_shardings,
nr_devices=primal.nr_devices,
apply_jit=True)
return export(fun_vjp_jax,
lowering_platforms=primal.lowering_platforms,
@@ -1154,13 +1169,24 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
if exported.uses_shape_polymorphism:
ctx.module_context.shape_poly_state.uses_dim_vars = True

axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
ctx_device_assignment = axis_context.device_assignment
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
ctx_device_assignment = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))
if len(ctx_device_assignment) != exported.nr_devices:
raise NotImplementedError(
f"Exported module {exported.fun_name} was lowered for "
f"{exported.nr_devices} devices and is called in a context with "
f"{len(ctx_device_assignment)} devices"
)

# Apply in_shardings
all_in_shardings = expand_in_shardings(exported.in_shardings,
exported.module_kept_var_idx,
len(args))
args = tuple(
wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, all_in_shardings))
wrap_with_sharding(ctx, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
submodule = ir.Module.parse(exported.mlir_module())
symtab = ir.SymbolTable(submodule.operation)
# The called function may have been exported with polymorphic shapes and called
@@ -1251,7 +1277,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
exported.out_avals, ctx.avals_out))
# Apply out_shardings
results = tuple(
wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
wrap_with_sharding(ctx, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
)
return results
@@ -1264,27 +1290,10 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
mlir.register_lowering(call_exported_p, _call_exported_lowering)

def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
exported: Exported,
x: ir.Value,
x_aval: core.AbstractValue,
x_sharding: Sharding) -> ir.Value:
if sharding_impls.is_unspecified(x_sharding):
if x_sharding is None:
return x
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
ctx_device_assignment = axis_context.device_assignment
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
ctx_device_assignment = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))
assert isinstance(x_sharding, sharding_impls.XLACompatibleSharding)
sharding_device_assignment = x_sharding._device_assignment
if len(ctx_device_assignment) != len(sharding_device_assignment):
raise NotImplementedError(
f"Exported module {exported.fun_name} was lowered for "
f"{len(sharding_device_assignment)} devices and is called in a context with "
f"{len(ctx_device_assignment)} devices"
)
return mlir.wrap_with_sharding_op(
ctx, x, x_aval,
x_sharding._to_xla_hlo_sharding(x_aval.ndim).to_proto())
ctx, x, x_aval, x_sharding.to_proto())
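Putting the change together, here is a hedged end-to-end sketch of what the new `nr_devices` field records and the device-count check it enables; API names follow `jax.experimental.export` as of this commit, and exact signatures should be treated as assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental.export import export

# Export for a single device; the Exported now records how many devices the
# module was lowered for and keeps HloSharding-or-None shardings.
exp = export.export(jnp.sin)(jax.ShapeDtypeStruct((4,), jnp.float32))
print(exp.nr_devices)    # 1
print(exp.in_shardings)  # (None,) -- unspecified input sharding

# Calling back into JAX must happen in a context with the same number of
# devices; otherwise _call_exported_lowering raises NotImplementedError.
y = export.call_exported(exp)(jnp.arange(4, dtype=jnp.float32))
```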