diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 19355e505641..e688b9c2d322 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -52,8 +52,6 @@ logger = logging.getLogger(__name__) -allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects - ### Policies def everything_saveable(*_, **__) -> bool: @@ -483,11 +481,11 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp -allowed_effects.add_type(lax_internal.InOutFeedEffect) +effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect) def remat_partial_eval(trace, *tracers, jaxpr, **params): assert not jaxpr.constvars - disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects) + disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( 'Effects not supported in partial-eval of `checkpoint`/`remat`: ' diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 12ddffa90d35..ef9981ee56ef 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -359,6 +359,7 @@ class OrderedIOEffect(effects.Effect): effects.control_flow_allowed_effects.add_type(IOEffect) effects.control_flow_allowed_effects.add_type(OrderedIOEffect) effects.ordered_effects.add_type(OrderedIOEffect) +effects.shardable_ordered_effects.add_type(OrderedIOEffect) def io_callback_impl( diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index b1b67b5991c6..8d89281cf9e3 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -49,8 +49,6 @@ map = safe_map zip = safe_zip -allowed_effects: effects.EffectTypeSet = ( - effects.custom_derivatives_allowed_effects) ### util @@ -416,7 +414,7 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args): yield outs, tuple(todo) # Ensure the aux output is immutable -allowed_effects.add_type(lax.InOutFeedEffect) +effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') @@ -424,7 +422,7 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts, symbolic_zeros): # TODO(mattjj): could do more checking here... 
  del in_avals, jvp_jaxpr_thunk, num_consts
-  disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects)
+  disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
   if disallowed_effects:
     raise NotImplementedError(
         f'Effects not supported in `custom_jvp`: {disallowed_effects}')
@@ -817,7 +815,7 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
   return core.jaxpr_as_fun(fun_jaxpr)(*args)

 def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
-  disallowed_effects = allowed_effects.filter_not_in(fun_jaxpr.effects)
+  disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fun_jaxpr.effects)
   if disallowed_effects:
     raise NotImplementedError(
         f'Effects not supported in `custom_vjp`: {disallowed_effects}')
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index e7c39b9538fd..d1b6cdd072fc 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -28,6 +28,7 @@
 import numpy as np

+import jax
 from jax._src import basearray
 from jax._src import core
 from jax._src import dtypes
@@ -187,57 +188,48 @@ def simple_impl(prim):
 RuntimeToken = Any

 class RuntimeTokenSet(threading.local):
-  tokens: dict[core.Effect, tuple[RuntimeToken, Device]]
-  output_tokens: dict[Device, RuntimeToken]
+  """See the docstring of the `effects.py` module for the token calling convention."""
+
+  # For each ordered effect, the token returned by the last dispatched
+  # computation, sharded over the devices in that computation.
+  current_tokens: dict[core.Effect, jax.Array]
+
+  # For each device, the runtime token returned by the last dispatched
+  # computation on that device.
   output_runtime_tokens: dict[Device, RuntimeToken]

   def __init__(self):
-    self.tokens = {}
-    # TODO(sharadmv): remove redundant output token dictionary when minimum
-    # jaxlib version is bumped to 0.3.16.
-    self.output_tokens = {}
+    self.current_tokens = {}
     self.output_runtime_tokens = {}

-  def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
-    s = SingleDeviceSharding(device)
-    if eff not in self.tokens:
-      inp = np.zeros(0, np.bool_)
-      indices = tuple(
-          s.addressable_devices_indices_map(inp.shape).values())
-      out = pxla.shard_args([device], [indices], [s], [inp])
-      self.tokens[eff] = out, device
-    elif self.tokens[eff][1] != device:
-      (old_token,), _ = self.tokens[eff]
-      indices = tuple(
-          s.addressable_devices_indices_map((0,)).values())
-      out = pxla.shard_args([device], [indices], [s], [old_token])
-      self.tokens[eff] = out, device
-    return self.tokens[eff][0]
-
-  def update_token(self, eff: core.Effect, token: RuntimeToken):
-    self.tokens[eff] = token, self.tokens[eff][1]
-
-  def set_output_token(self, device: Device, token: RuntimeToken):
-    # We're free to clobber the previous output token because on each
-    # device we have a total ordering of computations. Only the token
-    # from the latest computation matters. If this weren't the case
-    # we'd need to store a set of output tokens.
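A minimal sketch of the per-thread token bookkeeping above, assuming a build that includes this change. `io_callback(..., ordered=True)` carries `OrderedIOEffect`, and `runtime_tokens` is the private per-thread `RuntimeTokenSet` that the updated tests below also inspect:

```python
import jax
import jax.numpy as jnp
from jax._src import dispatch
from jax.experimental import io_callback

@jax.jit
def f(x):
  io_callback(lambda v: print(v), None, x, ordered=True)
  return x + 1

f(jnp.float32(1.))
# After dispatch, current_tokens maps the ordered effect to the (sharded)
# jax.Array token returned by the most recent computation on this thread.
print(dispatch.runtime_tokens.current_tokens)
jax.effects_barrier()  # blocks on the current tokens, then clears them
```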
-    self.output_tokens[device] = token
+  def get_token_input(self, eff: core.Effect,
+                      devices: list[Device]) -> jax.Array:
+    tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
+    s = jax.sharding.GSPMDSharding.get_replicated(devices)
+    indices = tuple(
+        s.addressable_devices_indices_map(tok.shape).values())
+    sharded_tok = pxla.shard_args(devices, [indices], [s], [tok])[0]
+    self.current_tokens[eff] = sharded_tok
+    return sharded_tok
+
+  def set_token_result(self, eff: core.Effect, token: jax.Array):
+    self.current_tokens[eff] = token

   def set_output_runtime_token(self, device: Device, token: RuntimeToken):
-    # TODO(sharadmv): remove this method when minimum jaxlib version is bumped
+    # We're free to clobber the previous output token because on each
+    # device we have a total ordering of computations. Only the token
+    # from the latest computation matters.
    self.output_runtime_tokens[device] = token

   def clear(self):
-    self.tokens = {}
-    self.output_tokens = {}
+    self.current_tokens = {}
     self.output_runtime_tokens = {}

   def block_until_ready(self):
-    for token, _ in self.tokens.values():
-      token[0].block_until_ready()
-    for token in self.output_tokens.values():
-      token[0].block_until_ready()
+    for token in self.current_tokens.values():
+      token.block_until_ready()
     for token in self.output_runtime_tokens.values():
       token.block_until_ready()
     self.clear()
diff --git a/jax/_src/effects.py b/jax/_src/effects.py
index 75e8cb8dd029..b5dcfc48a5fa 100644
--- a/jax/_src/effects.py
+++ b/jax/_src/effects.py
@@ -11,11 +11,51 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""JAX effects.
+
+JAX uses effects to describe computations that may have side-effects. Effects
+are associated with JAX primitive instances and Jaxprs.
+
+A primitive instance with an effect will be protected from dead-code
+elimination even if its result is unused.
+
+A special class of effects is the **ordered** effects
+(members of `effects.ordered_effects`).
+The lowering of a computation with ordered effects will have one additional
+input and one additional output for each ordered effect. These appear before
+the regular inputs/outputs, and are of type `i1[0]`. These tokens
+are threaded through the instructions with ordered effects to ensure that the
+compiler will not eliminate, replicate, or reorder the corresponding
+instructions.
+
+To ensure ordering across multiple computations, we maintain a
+per-thread set of the tokens returned by the last dispatched computation. There
+is one token per ordered effect, and it may be sharded over the devices
+used by the last dispatched computation. Upon dispatching a
+new computation with ordered effects, we take the current token, reshard it
+onto the devices of the computation being dispatched, and pass it as an input.
+We then update the current token to refer to the token output of
+the dispatched computation.
+
+When we have ordered effects, we also use the current token to implement
+`jax.effects_barrier`, which waits until the current tokens are ready.
+
+The implementation of `jax.effects_barrier` for unordered effects is a bit
+different, because for these effects we do not thread tokens in and out of
+dispatched computations. Instead, we use a `RuntimeToken`, an object returned
+when dispatching a computation and on which we can block until it is ready.
+We store for each thread the `RuntimeToken` returned by the last dispatched
+computation.
+
+For more details, see the design note:
+https://jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html.
+"""
+
 from __future__ import annotations

 from collections.abc import Iterable
 from typing import Any

+
 class Effect:
   """A generic side-effect."""

@@ -66,6 +106,14 @@ def filter_not_in(self, effects: Iterable[Effect]) -> list[Effect]:
 no_effects: Effects = set()
 ordered_effects: EffectTypeSet = EffectTypeSet()
+
+# By default, ordered effects are not allowed in multi-device computations,
+# because we cannot ensure a total order across devices. Optionally, an effect
+# can be declared shardable: its effects still appear in program order, but at
+# a given program point each participating device may produce a side effect,
+# with no guarantee of the relative ordering of those per-device effects.
+shardable_ordered_effects: EffectTypeSet = EffectTypeSet()
+
 lowerable_effects: EffectTypeSet = EffectTypeSet()
 control_flow_allowed_effects: EffectTypeSet = EffectTypeSet()
 custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet()
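Registering a shardable ordered effect then follows the same pattern used for `OrderedIOEffect` in `callback.py` above. A hypothetical sketch (the names `MyEffect`/`my_effect` are illustrative only):

```python
from jax._src import effects

class MyEffect(effects.Effect):
  pass

my_effect = MyEffect()

effects.lowerable_effects.add_type(MyEffect)  # the effect can be lowered
effects.ordered_effects.add_type(MyEffect)    # lowering threads i1[0] tokens
# Without this line, a computation carrying my_effect is rejected when
# lowered for more than one device:
effects.shardable_ordered_effects.add_type(MyEffect)
```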
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 3c6f09bd6c3b..da89294d4117 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1109,38 +1109,27 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
     self.has_unordered_effects = bool(unordered_effects)
     self.ordered_effects = ordered_effects
     self._local_devices = self.xla_executable.local_devices()
-    if ordered_effects:
-      assert len(self._local_devices) == 1
     self.keepalive = keepalive
     self.has_host_callbacks = has_host_callbacks
     self.kept_var_idx = kept_var_idx

   def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
-      device, = self._local_devices
-      tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
-                for eff in self.ordered_effects]
+      tokens = [
+          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
+          for eff in self.ordered_effects]
       input_bufs = [*tokens, *input_bufs]
     return input_bufs

   def _handle_token_bufs(self, token_bufs, sharded_token):
+    # token_bufs: Sequence[Sequence[jax.Array]]; for each ordered effect, the
+    # returned token buffer (as a singleton list).
+    # sharded_token: ShardedToken, containing the RuntimeTokens for each device.
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
-      dispatch.runtime_tokens.update_token(eff, token_buf)
-
-  def _call_with_tokens(self, input_bufs):
-    input_bufs = self._add_tokens_to_inputs(input_bufs)
-    out_bufs, sharded_token = (
-        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
-            input_bufs
-        )
-    )
-    num_output_tokens = len(self.ordered_effects)
-    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
-    self._handle_token_bufs(token_bufs, sharded_token)
-    return out_bufs
+      dispatch.runtime_tokens.set_token_result(eff, token_buf[0])

   @profiler.annotate_function
   def __call__(self, *args):
@@ -1152,10 +1141,10 @@ def __call__(self, *args):
       results = self.xla_executable.execute_sharded(
           input_bufs, with_tokens=True
       )
-      self._handle_token_bufs(
-          results.disassemble_prefix_into_single_device_arrays(
-              len(self.ordered_effects)),
-          results.consume_token())
+      result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
+          len(self.ordered_effects))
+      sharded_runtime_token = results.consume_token()
+      self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
     else:
       results = self.xla_executable.execute_sharded(input_bufs)
     if dispatch.needs_check_special():
@@ -1833,9 +1822,13 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   module_name = f"{api_name}_{fun_name}"

   if len(device_assignment) > 1:
-    if any(effects.ordered_effects.contains(eff) for eff
-           in closed_jaxpr.effects):
-      raise ValueError("Ordered effects are not supported for more than 1 device.")
+    unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
+    unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
+        unsupported_effects)
+    if unsupported_effects:
+      raise ValueError(
+          "The following ordered effects are not supported for "
+          f"more than 1 device: {unsupported_effects}")
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
   with dispatch.log_elapsed_time(
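A sketch of the new lowering-time check, assuming at least two devices and that ordered debug prints remain non-shardable (only `OrderedIOEffect` is registered as shardable in this change):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ['dev'])
spec = NamedSharding(mesh, PartitionSpec('dev'))

@jax.jit
def f(x):
  jax.debug.print("x = {x}", x=x, ordered=True)  # ordered, not shardable
  return x + 1

x = jax.device_put(jnp.arange(mesh.size), spec)
try:
  f(x)
except ValueError as e:
  # "The following ordered effects are not supported for more than 1 device: ..."
  print(e)
```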
diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py
index 54b4c778cf3c..4b0ee80a8dda 100644
--- a/jax/_src/lax/control_flow/__init__.py
+++ b/jax/_src/lax/control_flow/__init__.py
@@ -25,7 +25,6 @@
     _custom_linear_solve_impl,
     linear_solve_p)

-from jax._src.lax.control_flow.common import allowed_effects
 # Private utilities used elsewhere in JAX
 # TODO(sharadmv): lift them into a more common place
 from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index cf174a027ae6..52ab4871e720 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -21,7 +21,7 @@
 from jax._src import core
 from jax._src import linear_util as lu
 from jax._src.lax import lax
-from jax._src.effects import control_flow_allowed_effects as allowed_effects
+from jax._src import effects
 from jax._src import ad_util
 from jax._src import state
 from jax._src import util
@@ -33,7 +33,7 @@
 map, unsafe_map = safe_map, map

-allowed_effects.add_type(lax.InOutFeedEffect)
+effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)

 def _abstractify(x):
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index
f7673d81c3fb..b1debcf2b370 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -55,7 +55,6 @@ _make_closed_jaxpr, _prune_zeros, _typecheck_param, - allowed_effects, ) map, unsafe_map = safe_map, map @@ -144,7 +143,7 @@ def switch(index, branches, *operands): out_trees[0], jaxprs[0].out_avals, out_tree, jaxpr.out_avals) joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) - disallowed_effects = allowed_effects.filter_not_in(joined_effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') @@ -253,7 +252,7 @@ def cond(pred, true_fun, false_fun, *operands): out_tree, true_jaxpr.out_avals, false_out_tree, false_jaxpr.out_avals) joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) - disallowed_effects = allowed_effects.filter_not_in(joined_effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') @@ -325,7 +324,7 @@ def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects: def _cond_abstract_eval(*avals, branches, **_): joined_effects = _join_cond_effects(branches) - disallowed_effects = allowed_effects.filter_not_in(joined_effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') @@ -765,7 +764,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear): jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals) joined_effects = _join_cond_effects(branches) - disallowed_effects = allowed_effects.filter_not_in(joined_effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 1cb2c1420573..bf118b83c817 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -58,7 +58,7 @@ from jax._src.lax.control_flow.common import ( _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr, - _make_closed_jaxpr, _prune_zeros, _typecheck_param, allowed_effects) + _make_closed_jaxpr, _prune_zeros, _typecheck_param) _map = safe_map zip = safe_zip @@ -258,7 +258,7 @@ def _create_jaxpr(init): in_flat, jaxpr, consts, out_tree, out_tree_children = rest _check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out) - disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `scan`: {disallowed_effects}') @@ -1235,8 +1235,8 @@ def _create_jaxpr(init_val): _check_tree_and_avals("body_fun output and input", body_tree, body_jaxpr.out_avals, in_tree_children[0], init_avals) - effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) - disallowed_effects = allowed_effects.filter_not_in(effects) + joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) + disallowed_effects = 
effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') @@ -1268,7 +1268,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, del avals joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) - disallowed_effects = allowed_effects.filter_not_in(joined_effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') @@ -1698,7 +1698,7 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) - disallowed_effects = allowed_effects.filter_not_in(joined_effects) + disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index b6906a529bf9..4d45c46c3622 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -558,11 +558,11 @@ def test_runtime_tokens_should_update_after_running_effectful_function(self): def f(x): effect_p.bind(effect=foo_effect) return x + 1. - self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens) + self.assertNotIn(foo_effect, dispatch.runtime_tokens.current_tokens) f(2.) - prev_token = dispatch.runtime_tokens.tokens[foo_effect] + prev_token = dispatch.runtime_tokens.current_tokens[foo_effect] f(2.) - curr_token = dispatch.runtime_tokens.tokens[foo_effect] + curr_token = dispatch.runtime_tokens.current_tokens[foo_effect] self.assertIsNot(prev_token, curr_token) def test_can_lower_multiple_effects(self): @@ -575,19 +575,19 @@ def f(x): def g(x): effect_p.bind(effect=foo_effect) return x + 1. - self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens) - self.assertNotIn(foo2_effect, dispatch.runtime_tokens.tokens) - f(2.).block_until_ready() - foo_token = dispatch.runtime_tokens.tokens[foo_effect][0] - foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0] + self.assertNotIn(foo_effect, dispatch.runtime_tokens.current_tokens) + self.assertNotIn(foo2_effect, dispatch.runtime_tokens.current_tokens) f(2.) - self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0]) - self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0]) - foo_token = dispatch.runtime_tokens.tokens[foo_effect][0] - foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0] + foo_token = dispatch.runtime_tokens.current_tokens[foo_effect] + foo2_token = dispatch.runtime_tokens.current_tokens[foo2_effect] + f(2.) + self.assertIsNot(foo_token, dispatch.runtime_tokens.current_tokens[foo_effect]) + self.assertIsNot(foo2_token, dispatch.runtime_tokens.current_tokens[foo2_effect]) + foo_token = dispatch.runtime_tokens.current_tokens[foo_effect] + foo2_token = dispatch.runtime_tokens.current_tokens[foo2_effect] g(2.) 
- self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0]) - self.assertIs(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0]) + self.assertIsNot(foo_token, dispatch.runtime_tokens.current_tokens[foo_effect]) + self.assertIs(foo2_token, dispatch.runtime_tokens.current_tokens[foo2_effect]) class EffectOrderingTest(jtu.JaxTestCase): @@ -608,7 +608,6 @@ def f(x): jax.effects_barrier() self.assertListEqual(log, [2., 3.]) - @jtu.skip_on_devices("tpu") def test_ordered_effect_remains_ordered_across_multiple_devices(self): if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") @@ -640,18 +639,21 @@ def g(x): expected_log = [x_, y_, x_, y_, x_, y_] self.assertListEqual(log, expected_log) - @jtu.skip_on_devices("tpu") def test_different_threads_get_different_tokens(self): if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") tokens = [] def _noop(_): - tokens.append(dispatch.runtime_tokens.tokens[log_effect][0]) return () - @functools.partial(jax.jit, device=jax.devices()[0]) def f(x): - return callback_p.bind(x, callback=_noop, effect=log_effect, out_avals=[]) + # Runs in a thread. + res = jax.jit( + lambda x: callback_p.bind( + x, callback=_noop, effect=log_effect, out_avals=[]) + )(x) + tokens.append(dispatch.runtime_tokens.current_tokens[log_effect]) + return res t1 = threading.Thread(target=lambda: f(2.)) t2 = threading.Thread(target=lambda: f(3.)) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 88d191fe8a42..6f8117991469 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -13,7 +13,9 @@ # limitations under the License. import functools +import logging import textwrap +import time import unittest from absl.testing import absltest @@ -27,11 +29,11 @@ from jax._src import util from jax._src import xla_bridge from jax._src.lib import xla_client +from jax.experimental import io_callback from jax.experimental import maps from jax.experimental import pjit from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map -from jax.experimental import io_callback import jax.numpy as jnp from jax.sharding import Mesh import numpy as np @@ -1007,54 +1009,118 @@ def f(x, y): "Effects not supported in partial-eval of `checkpoint`"): f(2., 3.) 
- def test_can_use_io_callback_in_pjit(self): - _mut = 0 - def _cb(x): - nonlocal _mut - _mut = x.sum() - - def f(x): - io_callback(_cb, None, x) - return x - - mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev']) - spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')) - out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - f = pjit.pjit(f, in_shardings=spec, out_shardings=out_spec) - with mesh: - f(jnp.arange(mesh.size)) - jax.effects_barrier() - self.assertEqual(_mut, jnp.arange(mesh.size).sum()) - - def test_can_use_io_callback_in_pjit_with_sharding(self): - mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev']) - spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')) + @parameterized.named_parameters( + dict( + testcase_name=f'{ordered=}_{with_sharding=}', + ordered=ordered, + with_sharding=with_sharding, + ) + for ordered in [True, False] + for with_sharding in [True, False] + ) + def test_can_use_io_callback_in_pjit( + self, *, ordered: bool, with_sharding: bool + ): + devices = jax.devices() + mesh = jax.sharding.Mesh(np.array(devices), ['dev']) - _mut = 0 + _collected: list[int] = [] def _cb(x): - nonlocal _mut - _mut = x.sum() + nonlocal _collected + _collected.append(int(x.sum())) - callback_device = jax.devices()[-1] - callback_device_index = spec._device_assignment.index(callback_device) + io_callback_kwargs = dict(ordered=ordered) + callback_device = devices[0] + if with_sharding: + callback_device = devices[-1] + io_callback_kwargs['sharding'] = jax.sharding.SingleDeviceSharding( + callback_device + ) def f(x): - sharding = jax.sharding.SingleDeviceSharding(callback_device) - io_callback(_cb, None, x, sharding=sharding) + io_callback(_cb, None, x, **io_callback_kwargs) + io_callback(_cb, None, x + 1, **io_callback_kwargs) return x + in_spec = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('dev') + ) out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) - f = pjit.pjit(f, in_shardings=spec, out_shardings=out_spec) - inp = jnp.arange(mesh.size) + f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec) + expected = [] with mesh: - f(inp) - jax.effects_barrier() - self.assertEqual(_mut, jnp.arange(mesh.size).sum()) + x = jnp.arange(mesh.size) + f(x) + expected.extend([int(x.sum()), int((x + 1).sum())]) + f(x + 5) + expected.extend([int((x + 5).sum()), int((x + 6).sum())]) + + jax.effects_barrier() + if ordered: + self.assertAllClose(_collected, expected) + else: + self.assertEqual(len(_collected), len(expected)) + for v in expected: + self.assertIn(v, _collected) + callback_device_index = in_spec._device_assignment.index(callback_device) self.assertIn( f'{{maximal device={callback_device_index}}}', - str(f.lower(inp).compiler_ir(dialect='stablehlo')), + str(f.lower(x).compiler_ir(dialect='stablehlo')), ) + def test_sequence_pjit_io_callback_ordered(self): + # A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each + # pair on a different device assignment. + _collected: list[int] = [] + def _cb(i, x): + nonlocal _collected + # Sleep different amounts of time, to test the ordering. + time.sleep([0.02, 0.03, 0.04][len(_collected) % 3]) + logging.info('Collected iteration %s: %s', i, x) + _collected.append(int(x.sum())) + + def f_base(i, x): + io_callback(_cb, None, i, x, ordered=True) + io_callback(_cb, None, i, x + 1, ordered=True) + + nr_iterations = 8 + # TODO(zce): If I pin to 1 device below (jax.devices()[:1]) then this test + # flakes. 
It also flakes when pinned to 2 devices. It seems that repeatedly + # dispatching to the same device triggers the problem. + devices = jax.devices() + expected = [] # The expected value for _collected + for i in range(nr_iterations): + if len(devices) > 1: + devices_for_iteration = [ + devices[i % len(devices)], + devices[(i + 1) % len(devices)], + ] + else: + devices_for_iteration = devices + logging.info( + 'Running iteration %d on devices %s', i, devices_for_iteration + ) + mesh = jax.sharding.Mesh(np.array(devices_for_iteration), ['dev']) + in_spec = ( + jax.sharding.NamedSharding(mesh, None), + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')), + ) + out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + f = pjit.pjit(f_base, in_shardings=in_spec, out_shardings=out_spec) + with mesh: + x = jax.device_put( + np.arange(len(devices_for_iteration), dtype=np.int32) + 10 * i, + in_spec[1], + ) + f(i, x) + expected.extend([int(x.sum()), int((x + 1).sum())]) + f(i, x + 5) + expected.extend([int((x + 5).sum()), int((x + 6).sum())]) + + jax.effects_barrier() + self.assertEqual(_collected, expected) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())
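Condensed from the new tests above, a standalone sketch of the behavior this change enables: an ordered `io_callback` inside a multi-device `pjit`, legal now that `OrderedIOEffect` is registered as shardable (assumes at least one device):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback, pjit

collected = []

def cb(x):
  collected.append(int(x.sum()))

mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
in_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

def f(x):
  io_callback(cb, None, x, ordered=True)
  io_callback(cb, None, x + 1, ordered=True)
  return x

f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec)
with mesh:
  x = jnp.arange(mesh.size)
  f(x)

jax.effects_barrier()
# Ordered callbacks run in program order.
assert collected == [int(x.sum()), int((x + 1).sum())]
```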