Skip to content

Commit

Permalink
[callback] Add a flag to implement host_callback in terms of io_callb…
Browse files Browse the repository at this point in the history
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
  • Loading branch information
gnecula committed Apr 2, 2024
1 parent 4c41c12 commit 69e1202
Show file tree
Hide file tree
Showing 5 changed files with 545 additions and 147 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ Remember to align the itemized text with the first line of an item within a list
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
* The `jax.experimental.host_callback` module is deprecated.
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
new callbacks. See {jax-issue}`#20385` for a discussion.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
Expand Down Expand Up @@ -1426,7 +1428,7 @@ Changes:
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
Expand Down
6 changes: 5 additions & 1 deletion jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,11 @@ pytype_library(

pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
srcs = [
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
"experimental/host_callback.py",
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
],
visibility = ["//visibility:public"],
deps = [
":jax",
Expand Down
114 changes: 99 additions & 15 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
Expand Down Expand Up @@ -501,6 +502,7 @@ def power3_with_cotangents(x):
from __future__ import annotations

import atexit
import enum
from collections.abc import Sequence
import functools
import itertools
Expand All @@ -510,13 +512,15 @@ def power3_with_cotangents(x):
import traceback
from typing import Any, Callable, cast

import jax
from jax._src import api
from jax._src import core
from jax._src import config
from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.experimental import io_callback
from jax._src.interpreters import ad, batching, pxla
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
Expand Down Expand Up @@ -560,6 +564,15 @@ def power3_with_cotangents(x):
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
_HOST_CALLBACK_LEGACY = config.DEFINE_bool(
'jax_host_callback_legacy',
config.bool_env('JAX_HOST_CALLBACK_LEGACY', False),
help=(
'Use old implementation of host_callback, documented in the module docstring.'
'If False, use the jax.experimental.io_callback implementation. '
'See https://github.com/google/jax/issues/20385.'
)
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -591,20 +604,30 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
XlaLocalClient = xla_client.Client
DType = Any

class CallbackFlavor(enum.Enum):
"""Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
See https://github.com/google/jax/issues/20385."""
IO_CALLBACK = 1 # uses jax.experimental.io_callback
PURE = 2 # uses jax.pure_callback
DEBUG = 3 # uses jax.debug.callback, valid only when there are no results


def _deprecated_id_tap(tap_func,
arg,
*,
result=None,
tap_with_device=False,
device_index=0,
callback_flavor=CallbackFlavor.IO_CALLBACK,
**kwargs):
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime
Expand All @@ -628,6 +651,9 @@ def _deprecated_id_tap(tap_func,
device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
Returns:
``arg``, or ``result`` if given.
Expand Down Expand Up @@ -660,7 +686,8 @@ def _deprecated_id_tap(tap_func,
call_with_device=tap_with_device,
result_shape=None,
identity=True,
device_index=device_index)
device_index=device_index,
callback_flavor=callback_flavor)

if result is not None:
return result
Expand All @@ -675,13 +702,15 @@ def _deprecated_id_print(arg,
device_index=0,
output_stream=None,
threshold=None,
callback_flavor=CallbackFlavor.IO_CALLBACK,
**kwargs):
"""Like :func:`id_tap` with a printing tap function.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
On each invocation of the printing tap, the ``kwargs`` if present
will be printed first (sorted by keys). Then arg will be printed,
Expand All @@ -697,6 +726,9 @@ def _deprecated_id_print(arg,
built-in ``print``. The string will be passed as
``output_stream.write(s)``.
* ``threshold`` is passed to ``numpy.array2string``.
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
Expand All @@ -708,19 +740,22 @@ def _deprecated_id_print(arg,
arg,
result=result,
tap_with_device=tap_with_device,
device_index=device_index)
device_index=device_index,
callback_flavor=callback_flavor)


def _deprecated_call(callback_func: Callable, arg, *,
result_shape=None,
call_with_device=False,
device_index=0):
device_index=0,
callback_flavor=CallbackFlavor.IO_CALLBACK):
"""Make a call to the host, and expect a result.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/google/jax/issues/20385.
Args:
callback_func: The Python function to invoke on the host as
Expand Down Expand Up @@ -748,14 +783,26 @@ def _deprecated_call(callback_func: Callable, arg, *,
device_index: specifies from which device the tap function is invoked in a
SPMD program. Works only when using the outfeed implementation mechanism,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
See https://github.com/google/jax/issues/20385.
Returns:
the result of the ``callback_func`` invocation.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
if (not _HOST_CALLBACK_LEGACY.value and
callback_flavor == CallbackFlavor.DEBUG and
result_shape is not None):
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
"flavor of callback only when the `result_shape` is None. "
"See https://github.com/google/jax/issues/20385."
)
return _call(callback_func, arg, result_shape=result_shape,
call_with_device=call_with_device, identity=False,
device_index=device_index)
device_index=device_index, callback_flavor=callback_flavor)


# We need the wrapper function to have hash and equality defined since it is
Expand All @@ -766,6 +813,11 @@ def __init__(self, callback_func, identity, call_with_device):
self.callback_func = callback_func
self.identity = identity
self.call_with_device = call_with_device
if not _HOST_CALLBACK_LEGACY.value and call_with_device:
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
" do not support `tap_with_device` and `call_with_device`. "
"See https://github.com/google/jax/issues/20385.")

def __hash__(self):
return hash((self.callback_func, self.identity, self.call_with_device))
Expand All @@ -775,7 +827,16 @@ def __eq__(self, other):
self.identity == other.identity and
self.call_with_device == other.call_with_device)

def __call__(self, arg, device, transforms):
def __call__(self, *args, **kwargs):
if _HOST_CALLBACK_LEGACY.value:
return self._call_legacy(*args, **kwargs)
else:
if self.identity:
# For id_tap, we pass empty transforms, for backwards compatibility
return self.callback_func(args[0], ())
return self.callback_func(*args, **kwargs)

def _call_legacy(self, arg, device, transforms):
if self.identity:
# For id_tap, we pass the transforms, for backwards compatibility
if self.call_with_device:
Expand All @@ -797,14 +858,16 @@ def _call(callback_func: Callable,
result_shape=None,
call_with_device=False,
device_index=0,
identity=False):
# Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
identity=False,
callback_flavor=CallbackFlavor.IO_CALLBACK):
if _HOST_CALLBACK_LEGACY.value:
# Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
api.check_callable(callback_func)
flat_args, arg_treedef = tree_util.tree_flatten(arg)
for arg in flat_args:
dispatch.check_arg(arg)
for arg_ in flat_args:
dispatch.check_arg(arg_)
# See definition of outside_call_p for what parameters it takes
params: dict[str, Any] = {}
# TODO: wrap function
Expand All @@ -829,8 +892,26 @@ def _call(callback_func: Callable,

params["result_treedef"] = result_treedef
params["flat_results_aval"] = tuple(flat_results_aval)
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)

if _HOST_CALLBACK_LEGACY.value:
flat_results = outside_call_p.bind(*flat_args, **params)
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
else:
callback_device = jax.local_devices()[device_index]
sharding = jax.sharding.SingleDeviceSharding(callback_device)
callback_func = _CallbackWrapper(callback_func, identity,
call_with_device)
if callback_flavor == CallbackFlavor.DEBUG:
assert identity
jax.debug.callback(callback_func, arg)
elif callback_flavor == CallbackFlavor.PURE:
call_res = jax.pure_callback(callback_func, result_shape, arg,
sharding=sharding)
else:
call_res = io_callback(callback_func, result_shape, arg,
sharding=sharding,
ordered=True)
return call_res if not identity else arg


# We need the lock for when we use the CustomCall implementation of callbacks.
Expand All @@ -855,7 +936,6 @@ def _print_tap_func(
threshold: the value of numpy.array2string threshold parameter.
**kwargs: all other keyword args are printed before printing `arg`.
"""

def emit_str(s: str):
if output_stream is not None:
output_stream.write(s + "\n")
Expand Down Expand Up @@ -1844,6 +1924,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None):
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
if not _HOST_CALLBACK_LEGACY.value:
jax.effects_barrier()
return

logging_name = logging_name or ""
logger.debug("barrier_wait[%s]: start", logging_name)

Expand Down Expand Up @@ -1907,7 +1991,7 @@ def _deprecated_stop_outfeed_receiver():
_deprecation_msg = (
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
"is subsumed by the new JAX external callbacks. "
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.")
"See https://github.com/google/jax/issues/20385.")

_deprecations = {
# Added March 20, 2024
Expand Down
Loading

0 comments on commit 69e1202

Please sign in to comment.