Skip to content

Commit

Permalink
[callbacks] Add support for shardable ordered effects.
Browse files Browse the repository at this point in the history
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.

Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.

In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.

We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.

PiperOrigin-RevId: 566242557
  • Loading branch information
gnecula authored and jax authors committed Sep 18, 2023
1 parent 152af70 commit 32ee27b
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 141 deletions.
6 changes: 2 additions & 4 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@

logger = logging.getLogger(__name__)

allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects

### Policies

def everything_saveable(*_, **__) -> bool:
Expand Down Expand Up @@ -483,11 +481,11 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp

allowed_effects.add_type(lax_internal.InOutFeedEffect)
effects.remat_allowed_effects.add_type(lax_internal.InOutFeedEffect)

def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
disallowed_effects = effects.remat_allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
'Effects not supported in partial-eval of `checkpoint`/`remat`: '
Expand Down
1 change: 1 addition & 0 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ class OrderedIOEffect(effects.Effect):
effects.control_flow_allowed_effects.add_type(IOEffect)
effects.control_flow_allowed_effects.add_type(OrderedIOEffect)
effects.ordered_effects.add_type(OrderedIOEffect)
effects.shardable_ordered_effects.add_type(OrderedIOEffect)


def io_callback_impl(
Expand Down
8 changes: 3 additions & 5 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
map = safe_map
zip = safe_zip

allowed_effects: effects.EffectTypeSet = (
effects.custom_derivatives_allowed_effects)

### util

Expand Down Expand Up @@ -416,15 +414,15 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
yield outs, tuple(todo) # Ensure the aux output is immutable


allowed_effects.add_type(lax.InOutFeedEffect)
effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
num_consts, symbolic_zeros):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects)
disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_jvp`: {disallowed_effects}')
Expand Down Expand Up @@ -817,7 +815,7 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
return core.jaxpr_as_fun(fun_jaxpr)(*args)

def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
disallowed_effects = allowed_effects.filter_not_in(fun_jaxpr.effects)
disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fun_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_vjp`: {disallowed_effects}')
Expand Down
68 changes: 30 additions & 38 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np

import jax
from jax._src import basearray
from jax._src import core
from jax._src import dtypes
Expand Down Expand Up @@ -187,57 +188,48 @@ def simple_impl(prim):
RuntimeToken = Any

class RuntimeTokenSet(threading.local):
tokens: dict[core.Effect, tuple[RuntimeToken, Device]]
output_tokens: dict[Device, RuntimeToken]
"""See docstring for effect.py module for the calling convention for tokens."""

# For each ordered effect, the token returned by the last dispatched
# computation, sharded over the devices in that computation.
current_tokens: dict[core.Effect, jax.Array]

# For each device, the runtime token returned by the last dispatched
# computation on that device.
output_runtime_tokens: dict[Device, RuntimeToken]

def __init__(self):
self.tokens = {}
# TODO(sharadmv): remove redundant output token dictionary when minimum
# jaxlib version is bumped to 0.3.16.
self.output_tokens = {}
self.current_tokens = {}
self.output_runtime_tokens = {}

def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
s = SingleDeviceSharding(device)
if eff not in self.tokens:
inp = np.zeros(0, np.bool_)
indices = tuple(
s.addressable_devices_indices_map(inp.shape).values())
out = pxla.shard_args([device], [indices], [s], [inp])
self.tokens[eff] = out, device
elif self.tokens[eff][1] != device:
(old_token,), _ = self.tokens[eff]
indices = tuple(
s.addressable_devices_indices_map((0,)).values())
out = pxla.shard_args([device], [indices], [s], [old_token])
self.tokens[eff] = out, device
return self.tokens[eff][0]

def update_token(self, eff: core.Effect, token: RuntimeToken):
self.tokens[eff] = token, self.tokens[eff][1]

def set_output_token(self, device: Device, token: RuntimeToken):
# We're free to clobber the previous output token because on each
# device we have a total ordering of computations. Only the token
# from the latest computation matters. If this weren't the case
# we'd need to store a set of output tokens.
self.output_tokens[device] = token
def get_token_input(self, eff: core.Effect,
devices: list[Device]) -> jax.Array:
tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
s = NamedSharding(pxla.Mesh(devices, axis_names=["dev"]),
PartitionSpec([]))
s = jax.sharding.GSPMDSharding.get_replicated(devices)
indices = tuple(
s.addressable_devices_indices_map(tok.shape).values())
sharded_tok = pxla.shard_args(devices, [indices], [s], [tok])[0]
self.current_tokens[eff] = sharded_tok
return sharded_tok

def set_token_result(self, eff: core.Effect, token: jax.Array):
self.current_tokens[eff] = token

def set_output_runtime_token(self, device: Device, token: RuntimeToken):
# TODO(sharadmv): remove this method when minimum jaxlib version is bumped
# We're free to clobber the previous output token because on each
# device we have a total ordering of computations. Only the token
# from the latest computation matters.
self.output_runtime_tokens[device] = token

def clear(self):
self.tokens = {}
self.output_tokens = {}
self.current_tokens = {}
self.output_runtime_tokens = {}

def block_until_ready(self):
for token, _ in self.tokens.values():
token[0].block_until_ready()
for token in self.output_tokens.values():
token[0].block_until_ready()
for token in self.current_tokens.values():
token.block_until_ready()
for token in self.output_runtime_tokens.values():
token.block_until_ready()
self.clear()
Expand Down
48 changes: 48 additions & 0 deletions jax/_src/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX effects.
JAX uses effects to describe computations that may have side-effects. Effects
are associated with JAX primitive instances and Jaxprs.
A primitive instance with an effect will be protected from dead-code elimination
even if its result is unused.
A special class of effects are the **ordered** effects
(members of `effects.ordered_effects`).
The lowering of a computation with ordered effects will have one additional
input and one additional output for each ordered effect. These appear before
the regular inputs/outputs, and are of type `i1[0]`. These tokens
are threaded through the instructions with ordered effects to ensure that the
compiler will not eliminate, replicate, or reordered the corresponding
instructions.
To ensure the ordering across multiple computations we maintain a
per-thread set of the tokens returned by the last dispatched computation. There
is one token per ordered effect, and it may be sharded over the devices
used by the last dispatched computation. Upon dispatching a
new computation with ordered effects we take the current token, we shard it
on the devices for the computation to be dispatched and we pass it as an input.
Then we update the current token to refer to the token output of
the dispatched computation.
When we have ordered effects, we also use the current token to implement
`jax.barrier` which waits until the current tokens are ready.
The implementation of `jax.barrier` for unordered effects is a bit different,
because for these effects we do not thread tokens in and out of dispatched
computation. Instead, we use a `RuntimeToken`, which is an object returned when
dispatching a computation and on which we can block until is ready. We store
for each thread the `RuntimeToken` returned by the last dispatched computation.
For more details, see the design note:
https://jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html.
"""

from __future__ import annotations

from collections.abc import Iterable
from typing import Any


class Effect:
"""A generic side-effect."""

Expand Down Expand Up @@ -66,6 +106,14 @@ def filter_not_in(self, effects: Iterable[Effect]) -> list[Effect]:

no_effects: Effects = set()
ordered_effects: EffectTypeSet = EffectTypeSet()

# By default, ordered effects are not allowed in multi-device computations,
# because we cannot ensure a total order. Optionally, an effect can be
# declared as shardable, which means that effects will appear in program order
# but for a given program point we may see several side effects on the
# participating devices, and there is no guarantee of their relative ordering.
shardable_ordered_effects: EffectTypeSet = EffectTypeSet()

lowerable_effects: EffectTypeSet = EffectTypeSet()
control_flow_allowed_effects: EffectTypeSet = EffectTypeSet()
custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet()
Expand Down
43 changes: 18 additions & 25 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,38 +1109,27 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
self.has_unordered_effects = bool(unordered_effects)
self.ordered_effects = ordered_effects
self._local_devices = self.xla_executable.local_devices()
if ordered_effects:
assert len(self._local_devices) == 1
self.keepalive = keepalive
self.has_host_callbacks = has_host_callbacks
self.kept_var_idx = kept_var_idx

def _add_tokens_to_inputs(self, input_bufs):
if self.ordered_effects:
device, = self._local_devices
tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
for eff in self.ordered_effects]
tokens = [
dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
for eff in self.ordered_effects]
input_bufs = [*tokens, *input_bufs]
return input_bufs

def _handle_token_bufs(self, token_bufs, sharded_token):
# token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
# token buffer (as a singleton list).
# sharded_token: ShardedToken, containing the RuntimeTokens for each device
for i, device in enumerate(self._local_devices):
dispatch.runtime_tokens.set_output_runtime_token(
device, sharded_token.get_token(i))
for eff, token_buf in zip(self.ordered_effects, token_bufs):
dispatch.runtime_tokens.update_token(eff, token_buf)

def _call_with_tokens(self, input_bufs):
input_bufs = self._add_tokens_to_inputs(input_bufs)
out_bufs, sharded_token = (
self.xla_executable.execute_sharded_on_local_devices_with_tokens(
input_bufs
)
)
num_output_tokens = len(self.ordered_effects)
token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
self._handle_token_bufs(token_bufs, sharded_token)
return out_bufs
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])

@profiler.annotate_function
def __call__(self, *args):
Expand All @@ -1152,10 +1141,10 @@ def __call__(self, *args):
results = self.xla_executable.execute_sharded(
input_bufs, with_tokens=True
)
self._handle_token_bufs(
results.disassemble_prefix_into_single_device_arrays(
len(self.ordered_effects)),
results.consume_token())
result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
len(self.ordered_effects))
sharded_runtime_token = results.consume_token()
self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
else:
results = self.xla_executable.execute_sharded(input_bufs)
if dispatch.needs_check_special():
Expand Down Expand Up @@ -1833,9 +1822,13 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
module_name = f"{api_name}_{fun_name}"

if len(device_assignment) > 1:
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
unsupported_effects)
if len(unsupported_effects) > 0:
raise ValueError(
"The following ordered effects are not supported for "
f"more than 1 device: {unsupported_effects}")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))

with dispatch.log_elapsed_time(
Expand Down
1 change: 0 additions & 1 deletion jax/_src/lax/control_flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
_custom_linear_solve_impl,
linear_solve_p)

from jax._src.lax.control_flow.common import allowed_effects
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jax._src import core
from jax._src import linear_util as lu
from jax._src.lax import lax
from jax._src.effects import control_flow_allowed_effects as allowed_effects
from jax._src import effects
from jax._src import ad_util
from jax._src import state
from jax._src import util
Expand All @@ -33,7 +33,7 @@

map, unsafe_map = safe_map, map

allowed_effects.add_type(lax.InOutFeedEffect)
effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)


def _abstractify(x):
Expand Down
9 changes: 4 additions & 5 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
allowed_effects,
)

map, unsafe_map = safe_map, map
Expand Down Expand Up @@ -144,7 +143,7 @@ def switch(index, branches, *operands):
out_trees[0], jaxprs[0].out_avals,
out_tree, jaxpr.out_avals)
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `switch`: {disallowed_effects}')
Expand Down Expand Up @@ -253,7 +252,7 @@ def cond(pred, true_fun, false_fun, *operands):
out_tree, true_jaxpr.out_avals,
false_out_tree, false_jaxpr.out_avals)
joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
Expand Down Expand Up @@ -325,7 +324,7 @@ def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects:

def _cond_abstract_eval(*avals, branches, **_):
joined_effects = _join_cond_effects(branches)
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
Expand Down Expand Up @@ -765,7 +764,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, linear):
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
joined_effects = _join_cond_effects(branches)
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
Expand Down
Loading

0 comments on commit 32ee27b

Please sign in to comment.