diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 00e1dc370e7a..99f6df4c2f59 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -87,7 +87,6 @@ def func(...): ... from jax._src import core from jax._src import test_util as jtu -from jax._src.sharding_impls import UNSPECIFIED from jax._src import xla_bridge as xb @@ -104,6 +103,7 @@ class CompatTestData: mlir_module_text: str mlir_module_serialized: bytes xla_call_module_version: int # The version of XlaCallModule to use for testing + nr_devices: int = 1 # The dummy_data is used for getting started for adding a new test and for @@ -187,6 +187,7 @@ def run_one_test(self, func: Callable[..., jax.Array], will fail when the serialization changes. Otherwise, when checking old serializations you can specify what custom calls are expected in the current serialization. + nr_devices: the number of devices for which the data was serialized. """ if not isinstance(data, CompatTestData): raise ValueError(f"Expecting data: CompatTestData but got {data}. " @@ -202,7 +203,7 @@ def run_one_test(self, func: Callable[..., jax.Array], res_run_current = tuple(np.array(a) for a in res_run_current) logging.info("Result of current version run is %s", res_run_current) - serialized, module_str, module_version = self.serialize( + serialized, module_str, module_version, nr_devices = self.serialize( func, data, polymorphic_shapes=polymorphic_shapes, allow_unstable_custom_call_targets=allow_unstable_custom_call_targets) @@ -225,6 +226,7 @@ def run_one_test(self, func: Callable[..., jax.Array], mlir_module_text=r\"\"\"\n{module_str}\"\"\", mlir_module_serialized={serialized!r}, xla_call_module_version={module_version}, + nr_devices={nr_devices}, ) # End paste """ @@ -271,7 +273,7 @@ def serialize(self, func: Callable, data: CompatTestData, *, polymorphic_shapes: Optional[Sequence[str]] = None, allow_unstable_custom_call_targets: Sequence[str] = () - ) -> tuple[bytes, str, int]: + ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: @@ -281,7 +283,8 @@ def serialize(self, custom call targets besides those known as stable. Returns: a tuple with the (a) serialization, (b) the module contents as - a string (for debugging), and (c) the module serialization version. + a string (for debugging), (c) the module serialization version, + (d) the number of devices for which the module was serialized. """ # Use the native exporter, to make sure we get the proper serialization. 
args_specs = export.args_specs(data.inputs, polymorphic_shapes) @@ -296,7 +299,8 @@ def serialize(self, module_str = str(exported.mlir_module()) serialized = exported.mlir_module_serialized module_version = exported.serialization_version - return serialized, module_str, module_version + nr_devices = exported.nr_devices + return serialized, module_str, module_version, nr_devices def run_serialized(self, data: CompatTestData, polymorphic_shapes: Optional[Sequence[str]] = None): @@ -321,14 +325,15 @@ def _get_vjp(_): in_avals=tuple(in_avals), out_tree=out_tree, out_avals=tuple(out_avals), - in_shardings=(UNSPECIFIED,) * len(in_avals), - out_shardings=(UNSPECIFIED,) * len(out_avals), + in_shardings=(None,) * len(in_avals), + out_shardings=(None,) * len(out_avals), lowering_platforms=(data.platform,), ordered_effects=(), unordered_effects=(), disabled_checks=(), mlir_module_serialized=data.mlir_module_serialized, serialization_version=data.xla_call_module_version, + nr_devices=data.nr_devices, module_kept_var_idx=tuple(range(len(in_avals))), uses_shape_polymorphism=any(not core.is_constant_shape(a.shape) for a in in_avals), diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index 2b4f808083d1..9ac5ec9bcb96 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -122,7 +122,11 @@ def __hash__(self) -> int: _VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 -Sharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] +# The values of input and output sharding from the lowering. +LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] + +# None means unspecified sharding +Sharding = Optional[xla_client.HloSharding] @dataclasses.dataclass(frozen=True) class Exported: @@ -140,9 +144,9 @@ class Exported: out_avals: the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in `in_avals. - in_shardings: the flattened input shardings. Only for the inputs that are - specified in `module_kept_var_idx`. + in_shardings: the flattened input shardings, as long as `in_avals`. out_shardings: the flattened output shardings, as long as `out_avals`. + nr_devices: the number of devices that the module has been lowered for. lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', 'cuda', 'rocm'. See below for the calling convention for when there are multiple lowering platforms. @@ -274,6 +278,7 @@ class Exported: in_shardings: tuple[Sharding, ...] out_shardings: tuple[Sharding, ...] + nr_devices: int lowering_platforms: tuple[str, ...] ordered_effects: tuple[effects.Effect, ...] unordered_effects: tuple[effects.Effect, ...] 
@@ -521,14 +526,31 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: if version < _VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: ordered_effects = unordered_effects = () + nr_devices = len(lowering.compile_args["device_assignment"]) + def export_sharding(s: LoweringSharding, + aval: core.ShapedArray) -> Sharding: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + + all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], + module_kept_var_idx, + len(args_avals_flat)) + in_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(all_in_shardings, args_avals_flat)) + out_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) return Exported( fun_name=fun_name, in_tree=lowered.in_tree, out_tree=lowered.out_tree, in_avals=tuple(args_avals_flat), out_avals=tuple(out_avals_flat), - in_shardings=tuple(lowering.compile_args["in_shardings"]), - out_shardings=tuple(lowering.compile_args["out_shardings"]), + in_shardings=in_shardings, + out_shardings=out_shardings, + nr_devices=nr_devices, lowering_platforms=actual_lowering_platforms, ordered_effects=ordered_effects, unordered_effects=unordered_effects, @@ -921,16 +943,16 @@ def walk_operations(op): "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") raise ValueError(msg) -def expand_in_shardings(in_shardings: tuple[Sharding, ...], +def expand_in_shardings(in_shardings: Sequence[LoweringSharding], module_kept_var_idx: Sequence[int], - nr_inputs: int) -> tuple[Sharding, ...]: + nr_inputs: int) -> Sequence[LoweringSharding]: """Expands in_shardings with unspecified shardings for inputs not kept. Assumes in_shardings corresponds to module_kept_var_idx. """ assert len(in_shardings) == len(module_kept_var_idx) assert nr_inputs >= len(module_kept_var_idx) - all_in_shardings: list[Sharding] = [sharding_impls.UNSPECIFIED] * nr_inputs + all_in_shardings: list[LoweringSharding] = [sharding_impls.UNSPECIFIED] * nr_inputs for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings): all_in_shardings[idx] = in_s return tuple(all_in_shardings) @@ -938,37 +960,30 @@ def expand_in_shardings(in_shardings: tuple[Sharding, ...], # TODO(yashkatariya, necula): remove this function once we relax the checks # in the jit front-end. def canonical_shardings( + device_assignment: Sequence[jax.Device], in_shardings: Sequence[Sharding], out_shardings: Sequence[Sharding] ) -> tuple[Union[pxla.UnspecifiedValue, Sequence[sharding.XLACompatibleSharding]], Union[pxla.UnspecifiedValue, Sequence[sharding.XLACompatibleSharding]]]: - """Prepares canonical in_ and out_shardings for a jit invocation. + """Prepares canonical in_ and out_shardings for a pjit invocation. The pjit front-end is picky about what in- and out-shardings it accepts, e.g., if all are unspecified then the whole sharding should be the sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are replaced with the replicated sharding. 
- """ - # Prepare a replicated sharding, search in both the input and output shardings - specified_shardings = [ - s for s in itertools.chain(in_shardings, out_shardings) - if not sharding_impls.is_unspecified(s)] - if specified_shardings: - in_s = specified_shardings[0] # pjit will enforce that all have same devices - assert isinstance(in_s, sharding.XLACompatibleSharding) - replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) - else: - replicated_s = None + Returns: a pair with the canonicalized input and output shardings. + """ + replicated_s = sharding.GSPMDSharding.get_replicated(device_assignment) def canonicalize( ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue, Sequence[sharding.XLACompatibleSharding]]: - if all(sharding_impls.is_unspecified(s) for s in ss): + if all(s is None for s in ss): return sharding_impls.UNSPECIFIED return tuple( - s if not sharding_impls.is_unspecified(s) else replicated_s + sharding.GSPMDSharding(device_assignment, s) if s is not None else replicated_s for s in ss) return (canonicalize(in_shardings), canonicalize(out_shardings)) @@ -976,9 +991,9 @@ def _get_vjp_fun(primal_fun: Callable, *, in_tree: tree_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], - module_kept_var_idx: tuple[int, ...], in_shardings: tuple[Sharding, ...], out_shardings: tuple[Sharding, ...], + nr_devices: int, apply_jit: bool ) -> tuple[Callable, Sequence[core.AbstractValue]]: # Since jax.vjp does not handle kwargs, it is easier to do all the work @@ -1000,30 +1015,30 @@ def flattened_primal_fun_jax(*args_flat): itertools.chain(in_avals, map(lambda a: a.at_least_vspace(), out_avals))) - all_in_shardings = expand_in_shardings(in_shardings, - module_kept_var_idx, len(in_avals)) - vjp_in_shardings, vjp_out_shardings = canonical_shardings( - tuple(itertools.chain(all_in_shardings, out_shardings)), - all_in_shardings) - if apply_jit: + # Prepare a device assignment. For exporting purposes, all it matters + # is the number of devices. + device_assignment = jax.devices(jax.default_backend())[:nr_devices] + assert len(device_assignment) == nr_devices + vjp_in_shardings, vjp_out_shardings = canonical_shardings( + device_assignment, + tuple(itertools.chain(in_shardings, out_shardings)), + in_shardings) return pjit.pjit(fun_vjp_jax, in_shardings=vjp_in_shardings, out_shardings=vjp_out_shardings), vjp_in_avals else: - assert vjp_in_shardings == sharding_impls.UNSPECIFIED - assert vjp_out_shardings == sharding_impls.UNSPECIFIED return fun_vjp_jax, vjp_in_avals def _export_native_vjp(primal_fun, primal: Exported) -> Exported: # Export the VJP of `primal_fun_jax`. 
See documentation for Exported.vjp fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun, in_tree=primal.in_tree, - module_kept_var_idx=primal.module_kept_var_idx, in_avals=primal.in_avals, in_shardings=primal.in_shardings, out_avals=primal.out_avals, out_shardings=primal.out_shardings, + nr_devices=primal.nr_devices, apply_jit=True) return export(fun_vjp_jax, lowering_platforms=primal.lowering_platforms, @@ -1154,13 +1169,24 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, if exported.uses_shape_polymorphism: ctx.module_context.shape_poly_state.uses_dim_vars = True + axis_context = ctx.module_context.axis_context + if isinstance(axis_context, sharding_impls.ShardingContext): + ctx_device_assignment = axis_context.device_assignment + elif isinstance(axis_context, sharding_impls.SPMDAxisContext): + ctx_device_assignment = list(axis_context.mesh.devices.flat) + else: + raise NotImplementedError(type(axis_context)) + if len(ctx_device_assignment) != exported.nr_devices: + raise NotImplementedError( + f"Exported module {exported.fun_name} was lowered for " + f"{exported.nr_devices} devices and is called in a context with " + f"{len(ctx_device_assignment)} devices" + ) + # Apply in_shardings - all_in_shardings = expand_in_shardings(exported.in_shardings, - exported.module_kept_var_idx, - len(args)) args = tuple( - wrap_with_sharding(ctx, exported, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(args, ctx.avals_in, all_in_shardings)) + wrap_with_sharding(ctx, x, x_aval, x_sharding) + for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings)) submodule = ir.Module.parse(exported.mlir_module()) symtab = ir.SymbolTable(submodule.operation) # The called function may have been exported with polymorphic shapes and called @@ -1251,7 +1277,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra exported.out_avals, ctx.avals_out)) # Apply out_shardings results = tuple( - wrap_with_sharding(ctx, exported, x, x_aval, x_sharding) + wrap_with_sharding(ctx, x, x_aval, x_sharding) for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings) ) return results @@ -1264,27 +1290,10 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra mlir.register_lowering(call_exported_p, _call_exported_lowering) def wrap_with_sharding(ctx: mlir.LoweringRuleContext, - exported: Exported, x: ir.Value, x_aval: core.AbstractValue, x_sharding: Sharding) -> ir.Value: - if sharding_impls.is_unspecified(x_sharding): + if x_sharding is None: return x - axis_context = ctx.module_context.axis_context - if isinstance(axis_context, sharding_impls.ShardingContext): - ctx_device_assignment = axis_context.device_assignment - elif isinstance(axis_context, sharding_impls.SPMDAxisContext): - ctx_device_assignment = list(axis_context.mesh.devices.flat) - else: - raise NotImplementedError(type(axis_context)) - assert isinstance(x_sharding, sharding_impls.XLACompatibleSharding) - sharding_device_assignment = x_sharding._device_assignment - if len(ctx_device_assignment) != len(sharding_device_assignment): - raise NotImplementedError( - f"Exported module {exported.fun_name} was lowered for " - f"{len(sharding_device_assignment)} devices and is called in a context with " - f"{len(ctx_device_assignment)} devices" - ) return mlir.wrap_with_sharding_op( - ctx, x, x_aval, - x_sharding._to_xla_hlo_sharding(x_aval.ndim).to_proto()) + ctx, x, x_aval, x_sharding.to_proto()) diff --git a/jax/experimental/jax2tf/jax2tf.py 
b/jax/experimental/jax2tf/jax2tf.py index 63756c2150c4..5dfad2df0096 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -512,13 +512,13 @@ def run_fun_tf(self, def get_vjp_fun(self) -> tuple[Callable, Sequence[core.AbstractValue]]: return export._get_vjp_fun(self.fun_jax, - in_tree=self.exported.in_tree, - module_kept_var_idx=self.exported.module_kept_var_idx, - in_avals=self.exported.in_avals, - in_shardings=self.exported.in_shardings, - out_avals=self.exported.out_avals, - out_shardings=self.exported.out_shardings, - apply_jit=True) + in_tree=self.exported.in_tree, + in_avals=self.exported.in_avals, + in_shardings=self.exported.in_shardings, + out_avals=self.exported.out_avals, + out_shardings=self.exported.out_shardings, + nr_devices=self.exported.nr_devices, + apply_jit=True) class GraphSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, @@ -584,13 +584,13 @@ def get_vjp_fun(self) -> tuple[Callable, # except we use unspecified shardings, and we do not apply a jit on the # VJP. This matches the older behavior of jax2tf for graph serialization. return export._get_vjp_fun(self.fun_jax, - in_tree=self.in_tree, - module_kept_var_idx=tuple(range(len(self.args_avals_flat))), - in_avals=self.args_avals_flat, - in_shardings=(sharding_impls.UNSPECIFIED,) * len(self.args_avals_flat), - out_avals=self.outs_avals, - out_shardings=(sharding_impls.UNSPECIFIED,) * len(self.outs_avals), - apply_jit=False) + in_tree=self.in_tree, + in_avals=self.args_avals_flat, + in_shardings=(None,) * len(self.args_avals_flat), + out_avals=self.outs_avals, + out_shardings=(None,) * len(self.outs_avals), + nr_devices=1, # Does not matter for unspecified shardings + apply_jit=False) def dtype_of_val(val: TfVal) -> DType: @@ -890,10 +890,13 @@ def _convert_value(val, aval): # Do not apply XlaSharding for REPLICATED, on inputs and outputs. # This is an agreed convention, and also improves usability under TF eager. # See b/255511660. - if exported.in_shardings is not None: - args_flat_tf = tuple( - map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), - kept_args_flat_tf, kept_args_avals, exported.in_shardings)) + kept_in_shardings = [] + for i in exported.module_kept_var_idx: + kept_in_shardings.append(exported.in_shardings[i]) + args_flat_tf = tuple( + map(partial(_shard_value, + skip_replicated_sharding=tf.executing_eagerly()), + kept_args_flat_tf, kept_in_shardings)) res = tfxla.call_module(args_flat_tf, **call_module_attrs) # TODO(b/278940799): Replace the TF v1 API with public TF2 API. 
# Add the custom call tf.function into the default graph, so those functions @@ -904,10 +907,9 @@ def _convert_value(val, aval): concrete_fn._inference_function ) - if exported.out_shardings is not None: - res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), - res, exported.out_avals, exported.out_shardings)) - + res = list(map(partial(_shard_value, + skip_replicated_sharding=tf.executing_eagerly()), + res, exported.out_shardings)) res = tuple(map(_convert_value, res, exported.out_avals)) return res @@ -3405,17 +3407,21 @@ def split_to_logical_devices(tensor: TfVal, return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) +def _xla_compatible_sharding_to_hlo_sharding( + s: sharding.XLACompatibleSharding, + aval: core.ShapedArray) -> Optional[xla_client.HloSharding]: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + def _shard_value(val: TfVal, - aval: core.ShapedArray, - sd: sharding.XLACompatibleSharding, *, + sd: Optional[xla_client.HloSharding], *, skip_replicated_sharding: bool) -> TfVal: """Apply sharding to a TfVal.""" - if sharding_impls.is_unspecified(sd): + if sd is None: return val - sharding_proto: xla_client.OpSharding = cast( - xla_client.OpSharding, sd._to_xla_hlo_sharding(aval.ndim).to_proto()) # type: ignore - + sharding_proto = sd.to_proto() if (skip_replicated_sharding and op_shardings.is_op_sharding_replicated(sharding_proto)): return val @@ -3465,17 +3471,21 @@ def _pjit(*args: TfVal, _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars # Apply sharding annotation to the arguments + in_hlo_shardings: Sequence[Optional[xla_client.HloSharding]] = map( + _xla_compatible_sharding_to_hlo_sharding, in_shardings, _in_avals) sharded_args: Sequence[TfVal] = tuple( map(partial(_shard_value, skip_replicated_sharding=not _thread_local_state.enable_xla), - args, _in_avals, in_shardings)) + args, in_hlo_shardings)) results = _interpret_jaxpr(jaxpr, *sharded_args, extra_name_stack=util.wrap_name(name, "pjit"), fresh_constant_cache=False) + out_hlo_shardings: Sequence[Optional[xla_client.HloSharding]] = map( + _xla_compatible_sharding_to_hlo_sharding, out_shardings, _out_aval) sharded_results: Sequence[TfVal] = tuple( map(partial(_shard_value, skip_replicated_sharding=not _thread_local_state.enable_xla), - results, _out_aval, out_shardings)) + results, out_hlo_shardings)) return tuple(sharded_results) @@ -3483,12 +3493,14 @@ def _pjit(*args: TfVal, def _pjit_sharding_constraint(arg: TfVal, *, - sharding: sharding.NamedSharding, + sharding: sharding.XLACompatibleSharding, resource_env: maps.ResourceEnv, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, **kwargs) -> TfVal: - return _shard_value(arg, _in_avals[0], sharding, skip_replicated_sharding=False) + hlo_sharding = _xla_compatible_sharding_to_hlo_sharding(sharding, _in_avals[0]) + return _shard_value(arg, hlo_sharding, + skip_replicated_sharding=False) tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index d30116b82372..d866e0c2b89e 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -585,7 +585,7 @@ def func(x): def test_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) 
< 2: - self.skipTest("Test runs only on TPU with at least 2 devices") + self.skipTest("Test runs only on TPU with at least 2 devices") # Must use exactly 2 devices for expected outputs from ppermute devices = jax.devices()[:2] diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py b/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py index 69072d23f03b..f2d8be3b958a 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Sharding.py @@ -45,4 +45,5 @@ """, mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01\x1b\x05\x01\x05\x01\x03\x05\x03\x0b\x07\t\x0b\r\x0f\x03\x9d\x81\r\x01K\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0bS\x0b\x0b\x0b\x0b\x17\x0b\x13\x0b33\x0b\x0bS\x1b\x0b\x0b\x0f\x0b\x17SS\x13\x0b\x037\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x8f\x0b\x03\r\x17\x17\x07\x07\x17\x17\x02\xb6\x04\x1f\x1d1%\x05\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x1d3%\x05#\x03\x13\x05O\x07M\tY\x0bK\rQ\x0f[\x11K\x13K\x15K\x05%\x05'\x05)\x05+\x17'\x02\x02\x01\x05-\x03\x03\x19+\x05/\x03\x0b\x1d_\x1fS!k\x19q#s\x03\x0b\x1dU\x1fS!U\x19W#w\x051\x053\x03\x13\x05O\x07M\ty\x0bK\rQ\x0f]\x11K\x13K\x15K\x03\x059{;}\x055\x057\x1d?A\x059\x17'\x12\x05\x01\x03\x13\x05O\x07M\tY\x0bK\rQ\x0f]\x11K\x13K\x15K\x03\x13\x05O\x07M\t\x7f\x0bK\rQ\x0f[\x11K\x13K\x15K\x03\x03IW\x05;\x03\x01\x1d=\x0b\x03\x05\x01#\t\x03\x03u\x1d?\x1dA\x1dC\x1dE\x03\x03a\r\x05cegi\x1dG\x1dI\x1d\x1b\x1dK\x03\x03m\r\x03oM\x1dM\x1dO\x1dQ\r\x01\x1dS\x1dU\x13\x07\x05\x1f\x0bA\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dW)\x05\t\x11\x05)\x05\x05\x11\x05\t\x1d\x11\x03\x01\x03\x01)\x05\t\t\x07\x04\xd3\x05\x01\x11\x01)\x07\x03\x01\t\x05\x11\x01-\x05\x03\x05\x0b\x03\x01\x01\x0b\x07\x03G\x03\x01\x03\x01\x07\x04\x01\x03\x03\x05\x11\x03/\x05\x03\x11#\x03\x01\x01\x03\x07\x03\x1b\x03\x01\x03\x01\x03\x07\x17\x1b\x03\x01\x03\x03\x03\x07\x175\x03\x03\x03\x05\t\x07=7\x03\x03\x03\x07\x03\x07\x17C\x03\x03\x03\t\x03\x07\x17E\x03\x01\x03\x0b\x03\x07\x03\x1b\x03\x01\x03\r\x07\x04\x03\x03\x0f\x06\x03\x01\x05\x01\x00\x82\x13Y++\x11\x0f\x0b!\x1b\x11\x1b\x13'\x13\x11\x03\x0f\xa3)\x17\x9e\x02\x1e\x06\x19\x83\x1f\x15\x1d\x15\x13\x1f/!\x1d!)#\x1f\x19\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.sharding\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit_wrapped\x00jit(wrapped)/jit(main)/pjit[in_shardings=(GSPMDSharding({devices=[2,1]0,1}),) out_shardings=(GSPMDSharding({devices=[2,1]0,1}),) resource_env=ResourceEnv(Mesh(device_ids=array([0, 1]), axis_names=('a',)), ()) donated_invars=(False,) name=wrapped in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit(wrapped)/jit(main)/pjit(wrapped)/shard_map[mesh=Mesh(device_ids=array([0, 1]), axis_names=('a',)) in_names=({0: ('a',)},) out_names=({0: ('a',)},) check_rep=True]\x00channel_id\x00source_target_pairs\x00jit(wrapped)/jit(main)/pjit(wrapped)/ppermute[axis_name=a perm=((0, 1), (1, 
0))]\x00callee\x00\x00wrapped\x00Sharding\x00{devices=[2,1]0,1}\x00{manual}\x00jax.arg_info\x00args[0]\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00SPMDFullToShardShape\x00SPMDShardToFullShape\x00", xla_call_module_version=4, + nr_devices=2, ) # End paste diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index d054d58060bf..7f7d13f0c92e 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -119,7 +119,8 @@ def serialize( options=tf.saved_model.SaveOptions(experimental_custom_gradients=False), ) serialized = serialize_directory(saved_model_dir) - return serialized, module_str, module_version + nr_devices = 1 + return serialized, module_str, module_version, nr_devices def run_serialized( self,
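For reference, a minimal usage sketch of how the new Exported.nr_devices field and the Optional[HloSharding]-valued in_shardings/out_shardings surface to callers of jax.experimental.export. This assumes a single-device default backend; the toy function f below is hypothetical and not code from this change.

    import jax.numpy as jnp
    from jax.experimental.export import export

    def f(x):
      return jnp.sin(x)  # hypothetical toy function, exported under plain jit

    # No shardings are specified, so with this change each entry of
    # in_shardings/out_shardings is None (meaning "unspecified sharding").
    exp = export.export(f)(jnp.ones((4, 8), dtype=jnp.float32))
    assert exp.nr_devices == 1          # lowered for a single-device assignment
    assert exp.in_shardings == (None,)  # None replaces sharding_impls.UNSPECIFIED
    assert exp.out_shardings == (None,)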