From 4716b9830f23f7a269964c72fb21828b30ac40cf Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 5 Oct 2024 07:32:46 -0700 Subject: [PATCH] [host_callback] Remove most of the jax.experimental.host_callback module These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks. See https://github.com/google/jax/issues/20385 for a discussion. PiperOrigin-RevId: 682659677 --- CHANGELOG.md | 3 + jax/experimental/host_callback.py | 2008 +---------------------------- tests/BUILD | 16 - tests/host_callback_test.py | 1787 ------------------------- tests/host_callback_to_tf_test.py | 279 ---- tests/infeed_test.py | 3 - 6 files changed, 40 insertions(+), 4056 deletions(-) delete mode 100644 tests/host_callback_to_tf_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fd4a71aeeb50..d3f29b30882f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.numpy.isscalar` now returns True for any array-like object with zero dimensions. Previously it only returned True for zero-dimensional array-like objects with a weak dtype. + * `jax.experimental.host_callback` has been deprecated since March 2024, with + JAX version 0.4.26. Now we removed it. + See {jax-issue}`#20385` for a discussion of alternatives. ## jax 0.4.34 (October 4, 2023) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 1ab44a4fd586..f6f51ba5796a 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Primitives for calling Python functions on the host from JAX accelerator code. +"""Backwards compatibility shim for the deprecated host_callback APIs. .. warning:: The host_callback APIs are deprecated as of March 20, 2024. @@ -19,737 +19,30 @@ `new JAX external callbacks `_ See https://github.com/jax-ml/jax/issues/20385. -This module introduces the host callback functions :func:`call`, -:func:`id_tap`, and :func:`id_print`, that send their arguments from the device -to the host and invoke user-defined Python functions on the host, optionally -returning results back to the device computation. - -We show below how these functions can be used. We start with :func:`call`, -and we discuss examples of calling from JAX to arbitrary Python functions -on the CPU, e.g., to use NumPy CPU custom kernels. Then we -show uses of :func:`id_tap` and :func:`id_print`, which have the restriction -that they cannot return values from the host to the device. -These primitives are generally faster -because they are executed asynchronously with the device code. -In particular, they can be used to tap into and to debug JAX code. - -Using :func:`call` to call a host function and return results to device ------------------------------------------------------------------------ - -Use :func:`call` to invoke a computation on the host and return -NumPy arrays to the device computation. -Host computation is useful, e.g., when a device computation needs some data -that requires I/O on the host, or it needs a library that is available on the -host and you do not want to code it in JAX. -For example, eigen decomposition for general matrices in JAX does not work on TPU. -We can call the Numpy implementation from any JAX accelerator computation, -using a host computation:: - - # This function runs on the host - def host_eig(m: np.ndarray) -> np.ndarray: - return np.linalg.eigvals(m) - - # This function is used in JAX - def device_fun(m): - # We send "m" to the host, asking it to call "host_eig" and return the result. - # We have to specify the result shape and dtype, either in the form of an - # example return value or any object that has `shape` and `dtype` attributes, - # e.g., a NumPy array or a `jax.ShapeDtypeStruct`. - return hcb.call(host_eig, m, - # Given an input of shape (..., d, d), eig output has shape (..., d) - result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype)) - - -The :func:`call` function and the Python host function both take a single argument -and return a single result, but those can be pytrees. Note that we must tell -the :func:`call` what shape and dtype to expect from the host invocation, using -the ``result_shape`` keyword argument. -This is important because the device code is compiled with that expectation. -There will be an error raised at runtime if the actual invocation produces a -different result shape. In general, **such errors and also exceptions raised -by the host computation may be difficult to debug**. See the Debugging section -below. -This is a problem for :func:`call` but not for :func:`id_tap` because for the -latter the device code does not expect a returned value. - -The :func:`call` API can be used inside a jit or pmap computation or inside -cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be -separate calls to the host from each of the participating devices:: - - def host_sin(x, *, device): - # The ``device`` argument is passed due to ``call_with_device=True`` below. - print(f"Invoking host_sin with {x.shape} on {device}") - return np.sin(x) - - # Use pmap to run the computation on two devices - jax.pmap(lambda x: hcb.call(host_sin, x, - result_shape=x, - # Ask that the `host_sin` function be passed `device=dev` - call_with_device=True))( - np.ones((2, 4), dtype=np.float32)) - - # prints (in arbitrary order) - # Invoking host_sin with (4,) on cpu:0 - # Invoking host_sin with (4,) on cpu:1 - -Note that :func:`call` does not support any JAX transformations, but as we -show below one can make use of the -existing support for `Custom differentiation in JAX `_. - -Using :func:`id_tap` to call a Python function on the host, with no returned values ------------------------------------------------------------------------------------ - -The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when -you just want the side effects of your Python callback. These functions have -the advantage that once the arguments have been sent to the host, the device -computation can proceed without waiting for the Python callback to return. -For :func:`id_tap` you can specify your Python callback to be called, while -:func:`id_print` uses a built-in callback that prints the arguments to -`stdout` on the host. -The Python function passed -to :func:`id_tap` takes two positional arguments (the value tapped -from the device computation along with a ``transforms`` tuple, -described below). Optionally, the function may be passed a keyword argument -``device`` with the Device from which the value was tapped. - -A few examples:: - - def host_func(arg, transforms): - ...do something with arg... - - # calls host_func(2x, []) on host - id_tap(host_func, 2 * x) - - # calls host_func((2x, 3x), []) - id_tap(host_func, (2 * x, 3 * x)) # The argument can be a pytree - - # calls host_func(2x, [], device=jax.devices()[0]) - id_tap(host_func, 2 * x, tap_with_device=True) # Pass the device to the tap - - # calls host_func(2x, [], what='activation') - id_tap(functools.partial(host_func, what='activation'), 2 * x) - - # calls host_func(dict(x=x, y=y), what='data') - id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y)) - -The above examples can all be adapted to use :func:`id_print` instead, with -the difference that :func:`id_print` prints on the host the positional argument, -along with any additional kwargs and the automatic kwarg ``transforms``. - -Using :func:`barrier_wait` to wait until all callbacks have executed --------------------------------------------------------------------- - -If your Python callbacks have side-effects you may need to wait until the -computation has finished to ensure that the side-effects have been observed. -You can use the :func:`barrier_wait` function for that purpose:: - - accumulator = [] - def host_log(arg, transforms): - # We just record the arguments in a list - accumulator.append(arg) - - - def device_fun(x): - id_tap(host_log, x) - id_tap(host_log, 2. * x) - - jax.jit(device_fun)(1.) - jax.jit(device_fun)(1.) - - # At this point, we have started two computations, each with two - # taps, but they may not have yet executed. - barrier_wait() - # Now we know that all the computations started before `barrier_wait` - # on all devices, have finished, and all the callbacks have finished - # executing. - -Note that :func:`barrier_wait` will start one -tiny computation with one tap on each of the `jax.local_devices()` and -will wait for all these taps to be received. - -An alternative to using :func:`barrier_wait` is to just wait for the end -of the computation, if all the callbacks are :func:`call`:: - - accumulator = p[] - def host_log(arg): - # We just record the arguments in a list - accumulator.append(arg) - return 0. # return something - - - def device_fun(c): - y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32)) - z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32)) - return y + z # return something that uses both results - - res1 = jax.jit(device_fun)(1.) - res2 = jax.jit(device_fun)(1.) - res1.block_until_ready() - res2.block_until_ready() - -Behavior under parallelization transformations ----------------------------------------------- - -In presence of :func:`jax.pmap` the code will run on multiple devices and -each device will tap its values independently. -It may be helpful to use the ``tap_with_device`` option for :func:`id_print` -or :func:`id_tap`, so that you see which device is sending which data:: - - jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]) - # device=cpu:0 what=x,x^2: (3., 9.) # from the first device - # device=cpu:1 what=x,x^2: (4., 16.) # from the second device - -When using :func:`jax.pmap` with multiple devices on multiple hosts, every -host will receive callbacks from all of its local devices, with an operand -that corresponds to each device slice. For a -:func:`call`, the callback must return to each device only the slice of the -result that pertains to the corresponding device. - -When using the experimental :func:`pjit.pjit` the code will run on multiple -devices on different shards of the input. The current implementation of -host callbacks will ensure that a single device will collect and outfeed -the entire operand, in a single callback. The callback function is supposed -to return the entire array, which will then be sent in a single infeed to the -same device that issued the outfeed. This device is then responsible for -sending the required shards to the other devices:: - - with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]): - pjit.pjit(power3, in_shardings=(P("d"),), - out_shardings=(P("d"),))(np.array([3., 4.])) - - # device=TPU:0 what=x,x^2: ( [3., 4.], - # [9., 16.] ) - -Note that the collection of the operand on one device may result in OOM if -the operand was sharded across devices. - -When using :func:`pjit.pjit` with multiple devices on multiple hosts, only -the host for the device 0 (w.r.t. the mesh) will receive the callback, with -the operand collected -from all participating devices on all hosts. For a :func:`call`, the callback -must return the entire array for all devices on all hosts. - -Behavior under JAX autodiff transformations -------------------------------------------- - -When used under a JAX autodiff transformation, the host callback functions -operate on the primal values only. Consider the following example:: - - def power3(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2") - return y * x - - power3(3.) - # what: x,x^2 : (3., 9.) - -(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms`.) - -When used under :func:`jax.jvp` there will be one callback with the primal -values only:: - - jax.jvp(power3, (3.,), (0.1,)) - # what: x,x^2 : (3., 9.) - -Similarly for :func:`jax.grad`, we get a callback from the forward computation -only:: - - jax.grad(power3)(3.) - # what: x,x^2 : (3., 9.) - -If you want to invoke the callback on the tangents during a :func:`jax.jvp`, -you can use a custom_jvp. For example, you can define a function that does -nothing interesting except that its custom_jvp will print the tangents:: - - @jax.custom_jvp - def print_tangents(arg): - return None - - @print_tangents.defjvp - def print_tangents_jvp(primals, tangents): - arg_dot, = tangents - hcb.id_print(arg_dot, what="tangents") - return primals, tangents - -Then you use this function in the places where you want to tap the tangents:: - - def power3_with_tangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2") - print_tangents((x, y)) - return y * x - - jax.jvp(power3_with_tangents, (3.,), (0.1,)) - # what: x,x^2 : (3., 9.) - # what: tangents : (0.1, 0.6) - -You can do a similar thing for the cotangents during :func:`jax.grad`. This -time you must be careful to use in the rest of the computation the values whose -cotangents you want to tap. Hence we make the ``print_cotangents`` return -its argument:: - - @jax.custom_vjp - def print_cotangents(arg): - # Must return the argument for which we want the cotangent. - return arg - - # f_fwd: a -> (b, residual) - def print_cotangents_fwd(arg): - return print_cotangents(arg), None - # f_bwd: (residual, CT b) -> [CT a] - def print_cotangents_bwd(residual, ct_b): - hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) - return ct_b, - - print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) - - def power3_with_cotangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) - (x1, y1) = print_cotangents((x, y)) - # Must use the output of print_cotangents - return y1 * x1 - - jax.grad(power3_with_cotangents)(3.) - # what: x,x^2 : (3., 9.) - # what: cotangents : (9., 3.) - -If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals -for the backward pass, then the callbacks from the primal computation will -be called twice:: - - jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.) - # what: x,x^2 : (3., 9.) - # what: x,x^2 : (27., 729.) - # what: x,x^2 : (3., 9.) - -The callbacks are, in order from: the primal computation of the inner ``power3``, -the primal computation of the outer ``power3``, and the rematerialization -of the residuals for the inner ``power3``. - - -Behavior under jax.vmap ------------------------ - -The host callback functions :func:`id_print` and :func:`id_tap` support the -vectorization transformation :func:`jax.vmap`. - -For :func:`jax.vmap` the arguments to the callback are batched, -and the callback function is -passed an additional special ``transforms`` containing a list of transformation descriptors -in the form ``("batch", {"batch_dims": ...})``, where ``...``` denotes the -batched dimensions for the tapped values (one entry per argument, ` -`None`` denotes an argument that was broadcast). - - jax.vmap(power3)(np.array([2., 3.])) - # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.]) - -See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`. - -For more usage example, see tests/host_callback_test.py. - -Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support ------------------------------------------------------------------------------------- - -Another possible use for host computation is to invoke a library written for -another framework, such as TensorFlow. -In this case it becomes interesting to support JAX autodiff for host callbacks -by deferring to the autodiff mechanism in TensorFlow, -using the :func:`jax.custom_vjp` mechanism. - -This is relatively easy to do, once one understands both the JAX custom VJP -and the TensorFlow autodiff mechanisms. -The code for how this can be done is shown in the ``call_tf_full_ad`` -function in `host_callback_to_tf_test.py `_. -This example supports arbitrary higher-order differentiation as well. - -Note that if you just want to call TensorFlow functions from JAX, you can also -use the `jax2tf.call_tf function `_. - -Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support ------------------------------------------------------------------------------------------------- - -It should not be surprising that we can use host computation to invoke a JAX -computation on another device. The arguments are sent from the accelerator to -the host, and then to the outside device on which the JAX host -computation will run, and then the results are sent back to the original accelerator. - -The code for how this can be done is shown in the ``call_jax_other_device function`` -in `host_callback_test.py `_. - -Low-level details and debugging -------------------------------- - -The host callback functions will be executed for each device in the order in -which the send operations were performed on the device. - -The host callback functions for multiple devices may be interleaved. -The data from the devices is received by separate threads managed by the JAX -runtime (one thread per device). The runtime maintains a buffer of -configurable size (see the flag ``--jax_host_callback_max_queue_byte_size``). -When the buffer is full, all the receiving threads are paused -which eventually pauses the computation on devices. The runtime has one -additional thread for each device to invoke the Python user functions with the -received data. If the processing of the callbacks is slow, it may actually -lead to the runtime buffer filling up, and eventually pausing the computation -on the devices when they need to send something. -For more details on the outfeed receiver runtime mechanism see -`runtime code -`_. - -In order to pause the execution until all data from computations already -started on devices has arrived and has been processed, use :func:`barrier_wait`. - -Exceptions from the user-defined callback functions are logged along with their -stack traces, but the receiving threads are not stopped. Instead the last -exception is recorded and the subsequent :func:`barrier_wait` will -raise :exc:`CallbackException` if any exception had occurred -in one of the tap functions. This exception will include the text and the -stack trace of the last exception encountered. - -One further complication arises for callback functions that must return -results to the call origin device, such as :func:`call()`. This is handled -differently on CPU/GPU devices compared to TPU devices. - -On CPU/GPU devices, in order to avoid the device computation -being stuck waiting for a result that will never arrive, in case of any -error during the processing of the callback (whether raised by the user-code -itself or due to a mismatch of the returned value and the expected return_shape) -we send the device a "fake" result of shape ``int8[12345]``. -This will make the device -computation abort because the received data is different than the one that -it expects. On CPU the runtime will crash with a distinctive error message: - -``` -Check failed: buffer->length() == buffer_length (12345 vs. ...) -``` - -On GPU, the failure is more user-friendly and will be surfaced to the Python -program as: - -``` -RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ... -``` - -To debug the underlying cause for these messages, see the Debugging section. - -On TPU devices, there is currently no shape check for infeed, so we take the -safer route of not sending this fake result in case of errors. This means -that the computation will hang, and no exception will be raised (but any -exceptions in the callback functions will still appear in the logs). - -The current implementation uses the outfeed mechanism provided by XLA. The -mechanism itself is quite primitive in the sense that a receiver must know -exactly the shape of each incoming packet, and how many packets are expected. -This makes it hard to use for multiple kinds of data in the same computation, -and it is practically impossible to use it under conditionals or in loops -of non-constant iteration count. Furthermore, code that uses the outfeed -mechanism directly cannot be transformed by JAX. All these limitations are -addressed by the host callback functions. The tapping API introduced here -makes it easy to share the outfeed mechanism for multiple purposes, while -supporting all transformations. - -**Note that after you have used the host callback functions, you cannot -use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver` -if you later need to use lax.outfeed. - -Since the actual calls to your callback functions are made from the C++ -receiver, it may be hard to debug the calls. In particular, the stack trace -will not include the calling code. You can use the flag -``jax_host_callback_inline`` (or the environment variable -``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are -inlined. This works only if the calls are outside a staging context -(:func:`~jax.jit` or a control-flow primitive). - -The C++ `receiver -`_ -is started automatically on the first call to :func:`id_tap`. In order to stop -it properly, upon start an ``atexit`` handler is registered to call -:func:`barrier_wait` with the logging name "at_exit". - -There are a few environment variables that you can use to turn on logging -for the C++ outfeed `receiver backend -`_. - - * ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below. - * ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave - like INFO logs. This may be too much, but you will see which modules are - logging relevant info, and then you can select which modules to log from. - * ``TF_CPP_VMODULE==3`` (the module name can be either C++ or - Python, without the extension). - -You should also use the ``--verbosity=2`` flag so that you see the logs -from Python. - -For example, you can try to enable logging in the ``host_callback`` module: -``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple`` - -If you want to enable logging in lower-level implementation modules try: -``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple`` - -(For bazel tests use --test_arg=--vmodule=... - -Still to do: - * More performance tests. - * Explore implementation with outside compilation for TPU. - * Explore implementation with XLA CustomCall for CPU and GPU. - """ from __future__ import annotations -import atexit -import enum -from collections.abc import Callable, Sequence -import functools -import itertools +from collections.abc import Callable import logging -import math -import threading -import traceback -from typing import Any, cast +import warnings import jax -from jax._src import api -from jax._src import core -from jax._src import config -from jax import custom_derivatives -from jax._src import dtypes -from jax import lax -from jax.experimental import pjit from jax.experimental import io_callback -from jax._src.interpreters import ad, batching, pxla -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla -from jax._src import ad_checkpoint -from jax._src import compiler -from jax._src import dispatch -from jax._src import pretty_printer as pp -from jax._src import sharding_impls -from jax._src import source_info_util -from jax._src import tree_util -from jax._src import util -from jax._src import xla_bridge as xb -from jax._src.lib import xla_client -from jax._src.lib import xla_extension -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo -import numpy as np - - -_HOST_CALLBACK_INLINE = config.bool_flag( - 'jax_host_callback_inline', - config.bool_env('JAX_HOST_CALLBACK_INLINE', False), - help='Inline the host_callback, if not in a staged context.' -) -_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.int_flag( - 'jax_host_callback_max_queue_byte_size', - config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), - help=('The size in bytes of the buffer used to hold outfeeds from each ' - 'device. When this capacity is reached consuming outfeeds from the ' - 'device is paused, thus potentially pausing the device computation, ' - 'until the Python callback consume more outfeeds.'), - lower_bound=int(16 * 1e6) -) -_HOST_CALLBACK_OUTFEED = config.bool_flag( - 'jax_host_callback_outfeed', - config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False), - help=( - 'Use outfeed implementation for host_callback, even on CPU and GPU. ' - 'If false, use the CustomCall implementation. ' - 'Has no effect on TPU, since only the outfeed mechanism is implemented.' - ) -) -_HOST_CALLBACK_LEGACY = config.bool_flag( - 'jax_host_callback_legacy', - config.bool_env('JAX_HOST_CALLBACK_LEGACY', False), - help=( - 'Use old implementation of host_callback, documented in the module docstring.' - 'If False, use the new jax.experimental.io_callback implementation. ' - 'See https://github.com/jax-ml/jax/issues/20385.' - ) -) logger = logging.getLogger(__name__) -def _use_outfeed(platform: str) -> bool: - return (platform in ("tpu", "gpu", "cuda", "rocm") or - _HOST_CALLBACK_OUTFEED.value) - - -def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): - """Should be called whenever outfeed (or infeed) will be used.""" - if xb.using_pjrt_c_api(backend): - raise NotImplementedError( - "host_callback functionality isn't supported with PJRT C API. " - "See https://jax.readthedocs.io/en/latest/debugging/index.html and " - "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" - " for alternatives. Please file a feature request at " - "https://github.com/jax-ml/jax/issues if none of the alternatives are " - "sufficient.") - - -xops = xla_client._xla.ops - -XlaOp = xla_client.XlaOp -XlaShape = xla_client.Shape -XlaBuilder = xla_client.XlaBuilder -XlaDevice = xla_client.Device -XlaLocalClient = xla_client.Client -DType = Any - -class CallbackFlavor(enum.Enum): - """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. - - See https://github.com/jax-ml/jax/issues/20385. - """ - IO_CALLBACK = 1 # uses jax.experimental.io_callback - PURE = 2 # uses jax.pure_callback - DEBUG = 3 # uses jax.debug.callback, valid only when there are no results - - -def _deprecated_id_tap(tap_func, - arg, - *, - result=None, - tap_with_device=False, - device_index=0, - callback_flavor=CallbackFlavor.IO_CALLBACK, - **kwargs): - """Host-callback tap primitive, like identity function with a call to ``tap_func``. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - - ``id_tap`` behaves semantically like the identity function but has the - side-effect that a user-defined Python function is called with the runtime - value of the argument. - - Args: - tap_func: tap function to call like ``tap_func(arg, transforms)``, with - ``arg`` as described below and where ``transforms`` is the sequence of - applied JAX transformations in the form ``(name, params)``. If the - `tap_with_device` optional argument is True, then the invocation also - includes the device from which the value is tapped as a keyword argument: - ``tap_func(arg, transforms, device=dev)``. - arg: the argument passed to the tap function, can be a pytree of JAX - types. - result: if given, specifies the return value of ``id_tap``. This value is - not passed to the tap function, and in fact is not sent from the device to - the host. If the ``result`` parameter is not specified then the return - value of ``id_tap`` is ``arg``. - tap_with_device: if True then the tap function is invoked with the - device from which the tap originates as a keyword argument. - device_index: specifies from which device the tap function is invoked in a - SPMD program. Works only when using the outfeed implementation mechanism, - i.e., does not work on CPU unless --jax_host_callback_outfeed=True. - callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies - the flavor of callback to use. - See https://github.com/jax-ml/jax/issues/20385. - - Returns: - ``arg``, or ``result`` if given. - - The order of execution is by data dependency: after all the arguments and - the value of ``result`` if present, are computed and before the returned - value is used. At least one of the returned values of ``id_tap`` must be - used in the rest of the computation, or else this operation has no effect. - - Tapping works even for code executed on accelerators and even for code under - JAX transformations. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - if kwargs: - msg = ( - "Support for **kwargs in ``id_tap`` has been removed. Instead, " - "pre-apply keyword arguments, either by using a closure or by passing " - "``functools.partial(tap_func, **kwargs)``.") - raise TypeError(msg) - - if result is not None: - flat_results, _ = tree_util.tree_flatten(result) - for r in flat_results: - dispatch.check_arg(r) - - call_res = _call( - tap_func, - arg, - call_with_device=tap_with_device, - result_shape=None, - identity=True, - device_index=device_index, - callback_flavor=callback_flavor) - - if result is not None: - return result - else: - return call_res - - -def _deprecated_id_print(arg, - *, - result=None, - tap_with_device=False, - device_index=0, - output_stream=None, - threshold=None, - callback_flavor=CallbackFlavor.IO_CALLBACK, - **kwargs): - """Like :func:`id_tap` with a printing tap function. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - - On each invocation of the printing tap, the ``kwargs`` if present - will be printed first (sorted by keys). Then arg will be printed, - with the arrays stringified with ``numpy.array2string``. - - See the :func:`id_tap` documentation. - - Additional keyword arguments: - - * ``tap_with_device`` if True, will print also the device from which - the value originates. - * ``output_stream`` if given then it will be used instead of the - built-in ``print``. The string will be passed as - ``output_stream.write(s)``. - * ``threshold`` is passed to ``numpy.array2string``. - * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies - the flavor of callback to use. - See https://github.com/jax-ml/jax/issues/20385. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - printer = functools.partial(_print_tap_func, - output_stream=output_stream, - threshold=threshold, **kwargs) - return _deprecated_id_tap( - printer, - arg, - result=result, - tap_with_device=tap_with_device, - device_index=device_index, - callback_flavor=callback_flavor) - - -def _deprecated_call(callback_func: Callable, arg, *, +# We keep a shim for host_callback.call because it is still used in a few +# places in google. +def call(callback_func: Callable, + arg, + *, result_shape=None, call_with_device=False, device_index=0, - callback_flavor=CallbackFlavor.IO_CALLBACK): + callback_flavor=None): """Make a call to the host, and expect a result. .. warning:: @@ -757,1264 +50,37 @@ def _deprecated_call(callback_func: Callable, arg, *, The functionality is subsumed by the `new JAX external callbacks `_ See https://github.com/jax-ml/jax/issues/20385. - - Args: - callback_func: The Python function to invoke on the host as - ``callback_func(arg)``. If the ``call_with_device`` optional argument is True, - then the invocation also includes the ``device`` kwarg with the device - from which the call originates: ``callback_func(arg, device=dev)``. This function - must return a pytree of numpy ndarrays. - - arg: the argument passed to the callback function, can be a pytree of JAX - types. - - result_shape: a value that describes the expected shape and dtype of the - result. This can be a numeric scalar, from which a shape and dtype are - obtained, or an object that has ``.shape`` and ``.dtype`` attributes. - If the result of the callback is a pytree, then ``result_shape`` should - also be a pytree with the same structure. In particular, ``result_shape`` - can be `()` or `None` if the function does not have any results. - The device code containing ``call`` is compiled with the expected result shape and dtype, - and an error will be raised at runtime if the actual ``callback_func`` - invocation returns a different kind of result. - - call_with_device: if True then the callback function is invoked with the - device from which the call originates as a keyword argument. - - device_index: specifies from which device the tap function is invoked in a - SPMD program. Works only when using the outfeed implementation mechanism, - i.e., does not work on CPU unless --jax_host_callback_outfeed=True. - callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies - the flavor of callback to use. - See https://github.com/jax-ml/jax/issues/20385. - - Returns: - the result of the ``callback_func`` invocation. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. """ - if (not _HOST_CALLBACK_LEGACY.value and - callback_flavor is CallbackFlavor.DEBUG and - result_shape is not None): - raise NotImplementedError( - "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " - "flavor of callback only when the `result_shape` is None. " - "See https://github.com/jax-ml/jax/issues/20385." - ) - return _call(callback_func, arg, result_shape=result_shape, - call_with_device=call_with_device, identity=False, - device_index=device_index, callback_flavor=callback_flavor) - - -# We need the wrapper function to have hash and equality defined since it is -# used as a primitive keyword argument, and we want a compilation cache hit if -# the user uses the same function twice. -class _CallbackWrapper: - def __init__(self, callback_func, identity, call_with_device): - self.callback_func = callback_func - self.identity = identity - self.call_with_device = call_with_device - if not _HOST_CALLBACK_LEGACY.value and call_with_device: - raise NotImplementedError( - "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" - " do not support `tap_with_device` and `call_with_device`. " - "See https://github.com/jax-ml/jax/issues/20385.") - - def __hash__(self): - return hash((self.callback_func, self.identity, self.call_with_device)) - - def __eq__(self, other): - return (self.callback_func == other.callback_func and - self.identity == other.identity and - self.call_with_device == other.call_with_device) - - def __call__(self, *args, **kwargs): - if _HOST_CALLBACK_LEGACY.value: - return self._call_legacy(*args, **kwargs) - else: - if self.identity: - # For id_tap, we pass empty transforms, for backwards compatibility - return self.callback_func(args[0], ()) - return self.callback_func(*args, **kwargs) - - def _call_legacy(self, arg, device, transforms): - if self.identity: - # For id_tap, we pass the transforms, for backwards compatibility - if self.call_with_device: - return self.callback_func(arg, transforms, device=device) - else: - return self.callback_func(arg, transforms) - else: - if self.call_with_device: - return self.callback_func(arg, device=device) - else: - return self.callback_func(arg) - - -# Helper function to implement both `call` and `id_tap`. The two cases are -# differentiated by the `identity` flag. -def _call(callback_func: Callable, - arg, - *, - result_shape=None, - call_with_device=False, - device_index=0, - identity=False, - callback_flavor=CallbackFlavor.IO_CALLBACK): - if _HOST_CALLBACK_LEGACY.value: - # Lazy initialization - _initialize_outfeed_receiver( - max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) - api.check_callable(callback_func) - flat_args, arg_treedef = tree_util.tree_flatten(arg) - for arg_ in flat_args: - dispatch.check_arg(arg_) - # See definition of outside_call_p for what parameters it takes - params: dict[str, Any] = {} - # TODO: wrap function - params["callback"] = _CallbackWrapper(callback_func, identity, - call_with_device) - params["identity"] = identity - params["arg_treedef"] = arg_treedef - params["device_index"] = device_index - - if not identity: - # Turn abstract values into ShapesDtypeStruct - flat_results_shape, result_treedef = tree_util.tree_flatten(result_shape) - try: - flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.dtype(r, canonicalize=True)) - for r in flat_results_shape] - except Exception: - msg = ("result_shape should be a pytree of values with structure " - "matching the expected result of the callback function. The " - "values must be either numeric scalars, or must have 'shape' and " - f"'dtype' attributes. Got {result_shape}") - raise ValueError(msg) - - params["result_treedef"] = result_treedef - params["flat_results_aval"] = tuple(flat_results_aval) - - if _HOST_CALLBACK_LEGACY.value: - flat_results = outside_call_p.bind(*flat_args, **params) - return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) - else: - callback_device = jax.local_devices()[device_index] - sharding = jax.sharding.SingleDeviceSharding(callback_device) - callback_func = _CallbackWrapper(callback_func, identity, - call_with_device) - if callback_flavor is CallbackFlavor.DEBUG: - assert identity - jax.debug.callback(callback_func, arg) - return arg - elif callback_flavor is CallbackFlavor.PURE: - call_res = jax.pure_callback(callback_func, result_shape, arg, - sharding=sharding) - else: - call_res = io_callback(callback_func, result_shape, arg, - sharding=sharding, - ordered=True) - return call_res if not identity else arg - - -# We need the lock for when we use the CustomCall implementation of callbacks. -# The outfeed implementation is driven by a single thread from C++. -_print_tap_lock = threading.Lock() - - -def _print_tap_func( - arg, transforms, *, device=None, - output_stream=None, threshold=1024, **kwargs): - """The consumer for id_print. - - We provide this as a simple tapping function for printing. - This is **experimental** and may not want to add many features to it; - it should be easy for the user to roll their own printing function. - - Args: - device: the device from which the value originates (only if - ``tap_with_device`` was used for :func:`id_print`). - output_stream: a function whose `write` method is called with the strings to - be output. - threshold: the value of numpy.array2string threshold parameter. - **kwargs: all other keyword args are printed before printing `arg`. - """ - def emit_str(s: str): - if output_stream is not None: - output_stream.write(s + "\n") - else: - print(s) - - if transforms: - kwargs['transforms'] = [(name, params) if params else name - for name, params in transforms] - if device is not None: - kwargs['device'] = device - kv_pairs = " ".join([ - f"{k}: {v}" for k, v in sorted(kwargs.items()) - ]) - - def pp_val(arg) -> pp.Doc: - if isinstance(arg, tuple): - return pp.group(pp.concat([ - pp.text("( "), - pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])), - pp.text(" )") - ])) - elif isinstance(arg, list): - return pp.group(pp.concat([ - pp.text("[ "), - pp.nest(2, pp.join(pp.brk(), [pp_val(e) for e in arg])), - pp.text(" ]") - ])) - elif isinstance(arg, dict): - return pp.group(pp.concat([ - pp.text("{ "), - pp.nest(2, pp.join(pp.brk(), [ - pp.text(f"{k}=") + pp_val(v) for k, v in sorted(arg.items()) - ])), - pp.text(" }") - ])) - elif isinstance(arg, np.ndarray): - return pp.text(np.array2string(arg, threshold=threshold)) - else: - return pp.text(str(arg)) - - with _print_tap_lock: - if kv_pairs: - emit_str(kv_pairs) - emit_str(str(pp_val(arg))) - - -def _values_to_avals(vals) -> Sequence[core.ShapedArray]: - return tuple(core.raise_to_shaped(core.get_aval(v)) for v in vals) - -### The outside_call primitive -""" -This primitive is used to implement the `call` and `id_tap` functions. -It takes several positional arguments that are the flattened -according to `arg_treedef`. -The result of the primitive is computed based on the `identity` parameter, -as follows: - - * if `identity` is True, then the results are the same as the - positional arguments of the primitive (except perhaps the last couple of - arguments, see `has_token`). In this case, `result_treedef` and - `flat_results_aval` are ignored, and `args_treedef` describes the result also. - * if `identity` is False, then the results are those from - the call to the outside computation: - - flatten(callback(arg_treedef.unflatten(args), device=...)) - - In this case, the callback results must match `result_treedef` - and `flat_results_aval`. - -It takes the following parameters: - - * callback: the function to invoke with the unflattened arguments, - the device and the transforms: `callback(arrays, device, transforms)` - * arg_treedef: the treedef for the argument. - * identity: see description above. - * result_treedef, flat_results_aval: describes the expected result of the - callback. Only used when not `identity`. - * transforms: a tuple of the transformations that have been applied. Each - element of the tuple is itself a tuple with the first element the name - of the transform. The remaining elements depend on the transform. For - example, for `batch`, the parameters are the dimensions that have been - batched, and for `mask` the logical shapes. These are unpacked by - _outside_call_run_callback before passing to the user function. - * has_token: a boolean, when True it means that the last positional argument - is the current token. In this case, the result of the primitive is - going to be the non-token positional arguments, along with the updated - token. The tokens and this parameter are added after all the JAX - transformations, just before staging XLA. - * device_index: an integer, denotes from which device the invocation is from. - Works only when using the outfeed implementation mechanism, i.e., does - not work on CPU unless --jax_host_callback_outfeed=True. -""" -outside_call_p = core.Primitive("outside_call") -outside_call_p.multiple_results = True -core.outfeed_primitives.add(outside_call_p) - - -def _outside_call_abstract_eval(*args_a: pe.AbstractValue, - identity, **params) -> Sequence[pe.AbstractValue]: - if identity: - # Do some validation here - assert "result_treedef" not in params - assert "flat_results_aval" not in params - return args_a - assert params["device_index"] is not None - assert params["result_treedef"] is not None - assert params["flat_results_aval"] is not None - flat_results_aval = params["flat_results_aval"] - if "has_token" in params and params["has_token"]: - assert len(args_a) >= 2 - return flat_results_aval + args_a[-2:] - else: - return flat_results_aval - - -outside_call_p.def_abstract_eval(_outside_call_abstract_eval) - - -def _outside_call_impl(*args, **params): - assert "has_token" not in params - if _HOST_CALLBACK_INLINE.value: - device_index = params["device_index"] - device = xb.devices()[device_index] - results = _outside_call_run_callback(args, device, send_infeed=False, **params) - return results - else: - # We use the jitted-version of the primitive even for eager execution, both - # so that we do not duplicate logic, but also so that all outfeed is received - # by the outfeed_listeners, in the same thread from a given device. If we were - # to process the tap here, it would be coming from the main thread. Also, - # even in eager execution some primitives, such as while, are compiled. - # It would be confusing to process a sequence "id_tap; while" in two - # different threads. - return dispatch.apply_primitive(outside_call_p, *args, **params) - - -outside_call_p.def_impl(_outside_call_impl) - - -def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext, - *args_op, - identity, - device_index, - flat_results_aval=(), - **params): - # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - current_token = args_op[-2] - current_itoken = args_op[-1] - - args_to_outfeed = args_op[:-2] - # Some platforms refuse to infeed empty arrays. We generate constants - # instead. - non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), - flat_results_aval)) - need_callback_results_on_device = (not identity and - len(non_empty_flat_results_aval) > 0) - send_infeed = need_callback_results_on_device - generated_infeed = False # Keep track if we emitted an infeed op - for platform in ctx.module_context.platforms: - _raise_if_using_outfeed_with_pjrt_c_api( - xb.get_backend(platform) - ) - callback_id = _register_callback( - functools.partial( - _outside_call_run_callback, - send_infeed=send_infeed, - identity=identity, - flat_results_aval=flat_results_aval, - **params)) - - outfeed_sharding = xla_client.OpSharding() - outfeed_sharding.type = xla_client.OpSharding.Type.MAXIMAL - outfeed_sharding.tile_assignment_dimensions = [1] - outfeed_sharding.tile_assignment_devices = [device_index] - - # next_token = _callback_handler_data.receiver.add_outfeed( - # comp, current_token, callback_id, args_to_outfeed, device_index) - - xla_shapes = util.flatten( - xla.aval_to_xla_shapes(aval) for aval in ctx.avals_in[:-2]) - _callback_handler_data.receiver.register_outfeed(callback_id, xla_shapes) - outfeed_header_start = 271828 # Must match kOutfeedHeaderStart in C++ - header = mlir.ir_constant(np.array([outfeed_header_start, callback_id], - dtype=np.uint32)) - header_outfeed = hlo.OutfeedOp([header], current_token, - outfeed_config=ir.StringAttr.get('')) - mlir.set_sharding(header_outfeed, outfeed_sharding) - next_token, = header_outfeed.results - data_outfeed = hlo.OutfeedOp(args_to_outfeed, next_token, - outfeed_config=ir.StringAttr.get('')) - mlir.set_sharding(data_outfeed, outfeed_sharding) - next_token, = data_outfeed.results - - - if identity: - results = list(args_to_outfeed) - next_itoken = current_itoken - else: - empty_results = [ - mlir.ir_constant(np.zeros(aval.shape, aval.dtype)) - for aval in flat_results_aval - if _aval_is_empty(aval) - ] - if non_empty_flat_results_aval: - assert need_callback_results_on_device - after_outfeed_itoken = hlo.AfterAllOp([current_itoken, next_token]) - # We shard the infeed as AssignedDevice(device_index). This must match the - # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support - # this kind of sharding, we use a custom translation for infeed. - array_sharding_proto = xla_client.OpSharding() - array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL - array_sharding_proto.tile_assignment_dimensions = [1] - array_sharding_proto.tile_assignment_devices = [device_index] - - token_sharding_proto = xla_client.OpSharding() - token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED - infeed_sharding_proto = xla.tuple_sharding_proto( - [array_sharding_proto] * len(non_empty_flat_results_aval) + - [token_sharding_proto]) - - output_types = map(mlir.aval_to_ir_types, non_empty_flat_results_aval) - flat_output_types = util.flatten(output_types) - - layouts = ir.ArrayAttr.get([ - ir.ArrayAttr.get( - [mlir.i64_attr(i) - for i in range(len(aval.shape) - 1, -1, -1)]) - for aval in non_empty_flat_results_aval - ]) - infeed = hlo.InfeedOp(flat_output_types + [hlo.TokenType.get()], - after_outfeed_itoken, - infeed_config=ir.StringAttr.get(''), - layout=layouts) - mlir.set_sharding(infeed, infeed_sharding_proto) - non_empty_results = list(infeed.results[:-1]) - next_itoken = infeed.results[-1] - generated_infeed = True - results = [ - empty_results.pop(0) - if _aval_is_empty(result_aval) else non_empty_results.pop(0) - for result_aval in flat_results_aval - ] - else: - results = empty_results - next_itoken = current_itoken - - assert generated_infeed == send_infeed, ( - f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})") - assert identity or len(results) == len(flat_results_aval), ( - f"got {len(results)} but expected {len(flat_results_aval)}. " - f"identity = {identity}") - return results + [next_token, next_itoken] - - -def _outside_call_lowering(ctx: mlir.LoweringRuleContext, - *args, - has_token: bool, - identity: bool, - device_index: int, - flat_results_aval=(), - **params): - """MLIR Lowering for `CustomCall`-based HCB.""" - if len(ctx.module_context.platforms) > 1: - raise NotImplementedError("multi-platform lowering for host_callback") - platform = ctx.module_context.platforms[0] - use_outfeed = _use_outfeed(platform) - if use_outfeed: - return _outside_call_outfeed_lowering( - ctx, *args, - has_token=has_token, - identity=identity, - flat_results_aval=flat_results_aval, - device_index=device_index, - **params, - ) - else: - # TODO(necula): It seems that on CPU, with custom call, the device_index - # does not work, and the callback is always run on device_index=0 - if (device_index != 0 and "cpu" in ctx.module_context.platforms): - raise ValueError( - "The device_index feature on CPU works only when using outfeed.") - - # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - assert has_token - current_token = args[-2] - current_itoken = args[-1] - assert current_token.type == hlo.TokenType.get(), "The last two arguments must be tokens" - assert current_itoken.type == hlo.TokenType.get(), "The last two arguments must be tokens" - - args_to_outfeed = args[:-2] - # TODO(necula): this is a weak attempt to get the device. This works - # inside pmap, but does not work when we just execute on a single device, - # because in such executions we always get replica_id == 0. - replica_id = hlo.ReplicaIdOp() - callback_operands = [replica_id, *args_to_outfeed] - callback_operand_avals = [ - core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]] - if identity: - callback_flat_results_aval = [] - else: - callback_flat_results_aval = [*flat_results_aval] - - def wrapped_callback(*args): - replica_id, *arrays = args - result_arrays = _outside_call_run_callback( - arrays, - xb.local_devices()[replica_id], - send_infeed=False, - # The same parameters as outside_call_p - identity=identity, - flat_results_aval=flat_results_aval, - **params) - if identity: - # For identity, we do not pass the any results back to the device - result_arrays = () - return result_arrays - - if isinstance( - ctx.module_context.axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ): - # Apply maximal sharding so pjit only executes the callback on device device_index. - sharding = xla_client.OpSharding() - sharding.type = xla_client.OpSharding.Type.MAXIMAL - sharding.tile_assignment_dimensions = [1] - sharding.tile_assignment_devices = [device_index] - else: - sharding = None - results, next_token, keep_alive = mlir.emit_python_callback(ctx, - wrapped_callback, current_token, callback_operands, - callback_operand_avals, callback_flat_results_aval, # type: ignore[arg-type] - has_side_effect=True, sharding=sharding) - _callback_handler_data.keep_alives.append(keep_alive) - # We must put the two tokens at the end - if identity: - results = list(args_to_outfeed) - next_itoken = current_itoken - - assert identity or len(results) == len(flat_results_aval), ( - f"got {len(results)} but expected {len(flat_results_aval)}. " - f"identity = {identity}") - return list(results) + [next_token, next_itoken] - -mlir.register_lowering(outside_call_p, _outside_call_lowering) - -def _outside_call_run_callback( - arrays, device, *, - send_infeed=True, - # The same parameters as outside_call_p - callback, arg_treedef, - identity, result_treedef=None, flat_results_aval=None, - transforms=(), has_token=False): - """Performs the callback: - callback(arg, device, transforms) - - Called during the device computation once we have the argument, either from - an inlined callback or from an XLA computation outfeed. - - Returns the flat list of result arrays. If `send_infeed` then it will also send - the flat list of results to the device. - """ - - def _unpack_transforms(transforms) -> tuple[tuple[str, dict[str, Any]], ...]: - def _unpack_transform(name, *params): - if name == "batch": - return name, dict(batch_dims=params[0]) - elif name == "mask": - return name, dict(logical_shapes=5) - else: - assert not params, f"{name}, {params}" - return name, {} - - return tuple(_unpack_transform(*t) for t in transforms) - - try: - arg = api.tree_unflatten(arg_treedef, arrays) - unpacked_transforms = _unpack_transforms(transforms) - logger.debug( - "Outside call invoking call_func %s, device=%s, transforms=%s", - callback, device, unpacked_transforms - ) - res = callback(arg, device, unpacked_transforms) - if identity: - return tuple(arrays) - - else: # Check the type of the callback results - assert result_treedef is not None - assert flat_results_aval is not None - actual_flat_results, actual_result_treedef = tree_util.tree_flatten(res) - if actual_result_treedef != result_treedef: - msg = (f"Callback func {callback} should have returned a result " - f"with pytree {result_treedef} but returned " - f"{actual_result_treedef}") - raise TypeError(msg) - - canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results)) - actual_flat_results_aval = _values_to_avals(canonical_flat_results) - logger.debug( - "Outside call %s result %s. Sending to infeed for device %s.", - callback, flat_results_aval, device, - ) - - if not all(ea.strip_weak_type() == ra.strip_weak_type() - for ea, ra in util.safe_zip(flat_results_aval, - actual_flat_results_aval)): - msg = (f"Callback func {callback} should have returned a result " - "with abstract values " - f"{result_treedef.unflatten(flat_results_aval)} " - f"but returned {actual_result_treedef.unflatten(actual_flat_results_aval)}") - raise TypeError(msg) - - if send_infeed: - # Do not send the 0-sized arrays - non_empty_canonical_flat_results = tuple(filter(lambda r: not _aval_is_empty(r), - canonical_flat_results)) - device.transfer_to_infeed(non_empty_canonical_flat_results) - return canonical_flat_results - - except Exception as e: - logger.error("Outside call %s threw exception %s.", callback, e) - if send_infeed: - # Prepare some results to send in case of error. We are sending something - # with a distinctive shape (int8[12345]), one that is unlikely to be what the device - # expects. This should have the effect to abort the device computation, - # with an error message that we recognize. On TPU there seem to be no - # such check, and if we send anything at all the device computation will - # use some garbage data. So, on TPU we prefer to not send anything and let - # the computation hang. - # TODO: implement a proper error handling for TPU - if device.platform != "tpu": - canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))] - logger.debug("Outside call consumer %s exception %s. Sending to infeed the error result.", - callback, e) - device.transfer_to_infeed(tuple(canonical_flat_results)) - else: - logger.debug("Outside call consumer %s exception %s. On TPU we do not send infeed.", - callback, e) - raise e # Let the exception propagate - - -def _add_transform(params: dict, name: str, *transform_params) -> dict: - """Adds the `transform` to the params["transforms"]. - - Uses a tuple representation internally, will be unpacked before the - callback by _ConsumerCallable. - """ - new_transform = (name, *transform_params) - return dict( - params, transforms=(params.get("transforms", ()) + (new_transform,))) - - -def _aval_is_empty(aval) -> bool: - return math.prod(aval.shape) == 0 - -def _instantiate_zeros(tan, arg): - del arg - return ad.instantiate_zeros(tan) - -def _outside_call_jvp_rule(primals, tangents, **params): - assert "has_token" not in params - if not params["identity"]: - raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") - out_primals_tapped = outside_call_p.bind(*primals, **params) - return tuple(out_primals_tapped), tangents - - -ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule - -def _outside_call_transpose_rule(cts, *args, **params): - if not params["identity"]: - raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.") - assert "has_token" not in params - assert len(cts) == len(args) - cts_instantiated = tuple(map(_instantiate_zeros, cts, args)) - - # The args have been prepared by the id_tap_jvp_rule: tapped_primals, tapped_tangents, rest_primals, rest_tangents - transforms = params.get("transforms", ()) - if not transforms or transforms[-1] != ("jvp",): - # TODO: I should understand better when can this happen. It seems to arise - # in scan. - return outside_call_p.bind( - *cts_instantiated, - **_add_transform(params, "transpose")) - - assert False - - -ad.primitive_transposes[outside_call_p] = _outside_call_transpose_rule - - -def _outside_call_batching_rule(batched_args, batch_dims, **params): - if not params["identity"]: - raise NotImplementedError("batching rules are implemented only for id_tap, not for call.") - assert "has_token" not in params - new_params = _add_transform(params, "batch", batch_dims) - res = outside_call_p.bind(*batched_args, **new_params) - return res, batch_dims - - -batching.primitive_batchers[outside_call_p] = _outside_call_batching_rule - -#### -#### Jaxpr rewriting logic to thread the tokens through stateful primitives. -#### - - -def _rewrite_closed_jaxpr(cjaxpr: core.ClosedJaxpr, has_input_token: bool, - has_output_token: bool) -> core.ClosedJaxpr: - """Rewrites a ClosedJaxpr to thread the token, if needed.""" - new_jaxpr = _rewrite_jaxpr(cjaxpr.jaxpr, has_input_token, has_output_token) - return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts) - - -def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, - has_output_token: bool) -> core.Jaxpr: - """Rewrite a Jaxpr to thread the token, if needed.""" - assert has_input_token or not has_output_token - - if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr): - return jaxpr - - mk_new_var = core.gensym() - - eqns: list[core.JaxprEqn] = [] - # store the incoming tokens - last_token_var = mk_new_var(core.abstract_token) - last_itoken_var = mk_new_var(core.abstract_token) - if has_input_token: - invars = jaxpr.invars + [last_token_var, last_itoken_var] - else: - invars = jaxpr.invars - # We need tokens but none is given in input; make one depending on all invars - eqns.append( - core.new_jaxpr_eqn(jaxpr.invars, [last_token_var], - lax.create_token_p, {}, core.no_effects, source_info_util.current())) - eqns.append( - core.new_jaxpr_eqn(jaxpr.invars, [last_itoken_var], - lax.create_token_p, {}, core.no_effects, source_info_util.current())) - - for eqn in jaxpr.eqns: - if not core.primitive_uses_outfeed(eqn.primitive, eqn.params): - eqns.append(eqn) - else: - output_token_var = mk_new_var(last_token_var.aval) - output_itoken_var = mk_new_var(last_itoken_var.aval) - _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, - last_itoken_var, output_itoken_var, mk_new_var) - last_token_var = output_token_var - last_itoken_var = output_itoken_var - - outvars = jaxpr.outvars + ([last_token_var, last_itoken_var] if has_output_token else []) - new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr.effects) - return new_jaxpr - - -def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], - input_token_var: core.Var, output_token_var: core.Var, - input_itoken_var: core.Var, output_itoken_var: core.Var, - mk_new_var: Callable[[core.AbstractValue], core.Var]): - """Rewrite an `eqn` and append equations to `eqns`. - - This is only called if the current primitive uses outfeed. - Assume that the current token is in `input_token_var` and the resulting - token must end in `output_token_var`. - - Append the result of rewriting to `eqns`. - """ - if eqn.primitive is outside_call_p: - assert "has_token" not in eqn.params - eqns.append(eqn.replace(invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict(eqn.params, has_token=True))) - elif eqn.primitive is lax.while_p: - cond_jaxpr, _, body_jaxpr, _ = util.split_dict( - eqn.params, - ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) - if core.jaxpr_uses_outfeed(cond_jaxpr.jaxpr): - _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var, - input_itoken_var, output_itoken_var, - mk_new_var) - return - - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True), - cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False)))) - elif eqn.primitive is lax.cond_p: - branches, = util.split_dict(eqn.params, ["branches"]) - index, *operands = eqn.invars - new_invars = [index, *operands, input_token_var, input_itoken_var] - eqns.append( - eqn.replace( - invars=new_invars, outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - branches=tuple( - _rewrite_closed_jaxpr(jaxpr, True, True) - for jaxpr in branches)))) - elif eqn.primitive is lax.scan_p: - num_consts, num_carry, carry_jaxpr, linear, _, _, _, _ = util.split_dict( - eqn.params, - ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length", - "unroll", "_split_transpose"]) - # We add the tokens right at the end of carry - nr_const_and_carry = num_consts + num_carry - new_invars = eqn.invars[0:nr_const_and_carry] + [ - input_token_var, input_itoken_var] + eqn.invars[nr_const_and_carry:] - new_jaxpr = _rewrite_closed_jaxpr(carry_jaxpr, True, True) - # The rewrite has put the token at end, it has to be at end of carry - new_jaxpr_invars = new_jaxpr.jaxpr.invars - new_jaxpr_invars = ( - new_jaxpr_invars[0:nr_const_and_carry] + new_jaxpr_invars[-2:] + - new_jaxpr_invars[nr_const_and_carry:-2]) - new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(invars=new_jaxpr_invars)) - - new_jaxpr_outvars = new_jaxpr.jaxpr.outvars - new_jaxpr_outvars = ( - new_jaxpr_outvars[0:num_carry] + new_jaxpr_outvars[-2:] + - new_jaxpr_outvars[num_carry:-2]) - new_jaxpr = new_jaxpr.replace(jaxpr=new_jaxpr.jaxpr.replace(outvars=new_jaxpr_outvars)) - eqns.append( - eqn.replace( - invars=new_invars, - # Output token is at the end of carry result - outvars=(eqn.outvars[0:num_carry] + [output_token_var, output_itoken_var] + - eqn.outvars[num_carry:]), - params=dict( - eqn.params, - jaxpr=new_jaxpr, - num_carry=num_carry + 2, - linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:]))) - elif eqn.primitive is pxla.xla_pmap_p: - # We broadcast the input token into an array of tokens - call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True), - donated_invars=eqn.params["donated_invars"] + (False, False), - # Sharding/unsharding of tokens in pmap_translation are special - # cased to just pass-through the token - in_axes=eqn.params["in_axes"] + (None, None), - out_axes=eqn.params["out_axes"] + (0, 0)))) - elif eqn.primitive is custom_derivatives.custom_jvp_call_p: - fun_jaxpr = eqn.params["call_jaxpr"] - - def unreachable_thunk(): - assert False, "Should not be reached" - unreachable_thunk.reset_stores = lambda: None - - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - call_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True), - jvp_jaxpr_thunk=unreachable_thunk - ))) - elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p: - fun_jaxpr = eqn.params["fun_jaxpr"] - new_invars = [*eqn.invars, input_token_var, input_itoken_var] - - def unreachable_thunk(): - assert False, "Should not be reached" - - eqns.append( - eqn.replace( - invars=new_invars, - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True), - fwd_jaxpr_thunk=unreachable_thunk, - # The following are illegal values for the parameters, they - # should not be needed because this rewrite is just before - # compilation to XLA, which does not use those parameters. - bwd="illegal param", - out_trees="illegal param"))) - elif eqn.primitive is pjit.pjit_p: - jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True), - donated_invars=eqn.params["donated_invars"] + (False, False), - in_shardings=( - eqn.params["in_shardings"] - + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) - ), - out_shardings=( - eqn.params["out_shardings"] - + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) - ), - in_layouts=(eqn.params["in_layouts"] + (None, None)), - out_layouts=(eqn.params["out_layouts"] + (None, None)), - ), - ) - ) - elif eqn.primitive is ad_checkpoint.remat_p: - jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - jaxpr=_rewrite_jaxpr(jaxpr_, True, True), - ))) - else: - raise NotImplementedError(f"outfeed rewrite {eqn.primitive}") - - -def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], - input_token_var: core.Var, - output_token_var: core.Var, - input_itoken_var: core.Var, - output_itoken_var: core.Var, - mk_new_var: Callable): - """Rewrite a while whose cond has outfeed""" - cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict( - eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) - transformed_cond_jaxpr = _rewrite_closed_jaxpr(cond_jaxpr, True, True) - carry_invars = eqn.invars[cond_nconsts + body_nconsts:] - # pred1, token1, itoken1 = rewrite(COND)(cond_consts, carry_invars, input_token, input_itoken) - pred1_and_token1 = [ - mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars - ] - eqns.append( - core.new_jaxpr_eqn( - eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var], - pred1_and_token1, core.call_p, - dict( - call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_before"), - transformed_cond_jaxpr.jaxpr.effects, - eqn.source_info)) - # Make a new cond "lambda pred, carry, token, itoken: pred" - new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0]) - new_cond_invars = ( - [new_cond_pred_invar] + [mk_new_var(cv.aval) for cv in carry_invars] + - [mk_new_var(input_token_var.aval), - mk_new_var(input_itoken_var.aval)]) - new_cond_jaxpr = core.ClosedJaxpr( - core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], [], set()), []) - # Make a new body: - # "lambda cond_constvars, body_constvars, pred, carry, token, itoken: - # carry2, token2, itoken2 = rewrite(BODY)(body_constvars, carry, token, itoken) - # pred2, token3, itoken3 = rewrite(COND)(cond_constvars, carry2, token2, itoken2) - # (pred2, carry2, token3, itoken3) - transformed_body_jaxpr = _rewrite_closed_jaxpr(body_jaxpr, True, True) - new_body_invars_cond_constvars = [ - mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts] - ] - new_body_invars_body_constvars = [ - mk_new_var(v.aval) - for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts] - ] - new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0]) - new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars] - new_body_invars_token = mk_new_var(input_token_var.aval) - new_body_invars_itoken = mk_new_var(input_itoken_var.aval) - - new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars] - new_body_token2 = mk_new_var(input_token_var.aval) - new_body_itoken2 = mk_new_var(input_itoken_var.aval) - new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0]) - new_body_token3 = mk_new_var(input_token_var.aval) - new_body_itoken3 = mk_new_var(input_itoken_var.aval) - - new_body_eqns = [ - core.new_jaxpr_eqn( - new_body_invars_body_constvars + new_body_invars_carry + - [new_body_invars_token, new_body_invars_itoken], - new_body_carry2 + [new_body_token2, new_body_itoken2], - core.call_p, - dict( - call_jaxpr=transformed_body_jaxpr.jaxpr, - name="body"), - transformed_body_jaxpr.effects, - eqn.source_info), - core.new_jaxpr_eqn( - new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2], - [new_body_pred2, new_body_token3, new_body_itoken3], core.call_p, - dict( - call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_body"), - transformed_cond_jaxpr.effects, - eqn.source_info) - ] - effects = core.join_effects(*(eqn.effects for eqn in new_body_eqns)) - new_body_jaxpr = core.ClosedJaxpr( - core.Jaxpr([], (new_body_invars_cond_constvars + - new_body_invars_body_constvars + [new_body_invars_pred] + - new_body_invars_carry + [new_body_invars_token, new_body_invars_itoken]), - ([new_body_pred2] + new_body_carry2 + [new_body_token3, new_body_itoken3]), - new_body_eqns, effects), []) - - pred_out = mk_new_var(cond_jaxpr.out_avals[0]) - eqns.append( - core.new_jaxpr_eqn( - (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] + - carry_invars + pred1_and_token1[1:]), - ([pred_out] + eqn.outvars + [output_token_var, output_itoken_var]), - lax.while_p, - dict( - cond_jaxpr=new_cond_jaxpr, - cond_nconsts=0, - body_jaxpr=new_body_jaxpr, - body_nconsts=cond_nconsts + body_nconsts), - new_body_jaxpr.effects, - eqn.source_info)) - - -# We need an identity primitive to simplify rewriting -id_p = core.Primitive("id") -id_p.multiple_results = True -id_p.def_impl(lambda *args: args) -id_p.def_abstract_eval(lambda *args: args) -mlir.register_lowering(id_p, lambda ctx, *args: args) - -dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False) - - -class CallbackException(Exception): - """Signals that some callback function had exceptions. - - Raised by :func:`barrier_wait`. See the :mod:`jax.experimental.host_callback` - module documentation for details. - """ - pass - -TapFunctionException = CallbackException # For backwards compatibility - -class _CallbackHandlerData: - """Keep track of the outfeed receiver data.""" - receiver: Any - initialized: bool - on_exit: bool - lock: threading.Lock - last_callback_exception: tuple[Exception, str] | None - clients: tuple[XlaLocalClient, ...] - devices: tuple[XlaDevice, ...] - consumer_registry: dict[Callable, int] - consumer_registry_by_id: dict[int, Callable] - - def __init__(self): - self.receiver = None # Initialize lazily, when first needed - self.initialized = False - self.on_exit = False - self.lock = threading.Lock() - self.last_callback_exception = None - self.clients = () - self.devices = () - # The consumer registries must be live for the lifetime of the program, - # because we may have cached compilations that embed consumer ids, and we - # do not want the id reused for other shapes. - # Used only for the outfeed mechanism. - self.callback_registry = {} - self.callback_registry_by_id = {} - # For now we keep here the keep_alives for the emit_python_callback. This is - # a leak. We ought to attach these to the executable. - self.keep_alives = [] - - def stop(self): - """Wait for all pending outfeeds and stop the receiver.""" - self.receiver = None # GC will trigger the destructor - self.initialized = False - self.clients = () - self.devices = () - # Do not clear the consumer registries. - - -_callback_handler_data = _CallbackHandlerData() - - -# This function is called from C++; it must not allow exceptions through. -def _callback_input_received(device, consumer_id, arrays: tuple): - array_repr = ", ".join([f"({a.dtype}{a.shape})" for a in arrays]) - logger.debug("Callback input received on device %s for consumer %s arrays: %s", - device, consumer_id, array_repr) - callback = _callback_handler_data.callback_registry_by_id.get(consumer_id) - assert callback is not None, "We should have crashed in the runtime" - try: - return callback(arrays, device) - except Exception as e: - formatted_e = traceback.format_exc() - logger.error("Postponing exception raised in callback function: %s", formatted_e) - _callback_handler_data.last_callback_exception = (e, formatted_e) - - -def _register_callback(callback: Callable) -> int: - """Registers a callback function, cache by hash of callback. - - The callback is a function to be invoked as `callback(arrays, device)`. - """ - callback_id = _callback_handler_data.callback_registry.get(callback) - if callback_id is not None: - return callback_id - callback_id = hash(callback) & 0xFFFFFFFC # pybind11 has trouble here with large ints - callback_id += 1 # Reserve the consumer ID 0 - assert callback_id not in _callback_handler_data.callback_registry, ( - "callback id collision") - _callback_handler_data.callback_registry[callback] = callback_id - _callback_handler_data.callback_registry_by_id[callback_id] = callback - return callback_id - - -def _initialize_outfeed_receiver( - max_callback_queue_size_bytes: int = int(256 * 1e6)): - """Creates and starts the outfeed_receiver. - - This function is called lazily only when we compile an id_tap. - - Args: - * clients: the list of clients (backends) on whose devices to listen on. - * max_callback_queue_size_bytes: an optional integer to bound the maximum - size of arrays in the callback queue. When this limit is reached the - device listener pauses. - """ - outfeed_receiver_module = xla_extension.outfeed_receiver - - with _callback_handler_data.lock: - if _callback_handler_data.initialized: - return - - # By default, all devices on all supported backends. - clients = [backend for name, backend in xb.backends().items() - if name in ("cpu", "cuda", "rocm", "tpu")] - devices = list( - itertools.chain(*[backend.local_devices() for backend in clients])) - _callback_handler_data.clients = clients # type: ignore[assignment] - _callback_handler_data.devices = devices # type: ignore[assignment] - clients_with_outfeed = [c for c in clients if _use_outfeed(c.platform)] - for client in clients_with_outfeed: - _raise_if_using_outfeed_with_pjrt_c_api(client) - if clients_with_outfeed: - devices_with_outfeed = list( - itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed])) - if logger.isEnabledFor(logging.DEBUG): - device_repr = ", ".join([str(d) for d in devices_with_outfeed]) - logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s", - device_repr, max_callback_queue_size_bytes) - _callback_handler_data.receiver = outfeed_receiver_module.start( - _callback_input_received, tuple(clients_with_outfeed), - max_callback_queue_size_bytes, - compiler.get_compile_options(1, 1).executable_build_options) - - def exit_handler(): - # Prevent logging usage during compilation, gives errors under pytest - dispatch._on_exit = True - if not _callback_handler_data.on_exit: - _callback_handler_data.on_exit = True - _deprecated_barrier_wait("at_exit") - - atexit.register(exit_handler) # We wait as long as we have callbacks - _callback_handler_data.initialized = True - - -def _deprecated_barrier_wait(logging_name: str | None = None): - """Blocks the calling thread until all current outfeed is processed. - - Waits until all callbacks from computations already running on all devices - have been received and processed by the Python callbacks. Raises - CallbackException if there were exceptions while processing the callbacks. - - This works by enqueueing a special tap computation to all devices to which - we are listening for outfeed. Once all those tap computations are done, we - return from barrier_wait. - - Note: If any of the devices are busy and cannot accept new computations, - this will deadlock. - - Args: - logging_name: an optional string that will be used in the logging statements - for this invocation. See `Debugging` in the module documentation. - - For more details see the :mod:`jax.experimental.host_callback` module documentation. - """ - if not _HOST_CALLBACK_LEGACY.value: - jax.effects_barrier() - return - - logging_name = logging_name or "" - logger.debug("barrier_wait[%s]: start", logging_name) - - lock = threading.Lock() - cv = threading.Condition(lock=lock) - devices_at_barrier = [] # Protected by lock - def barrier_tap_received(dev_idx, _): - device = _callback_handler_data.devices[dev_idx] - logger.debug( - "barrier_wait[%s]: at barrier_tap for device %s. Thread %s", - logging_name, device, threading.current_thread() - ) - with lock: - devices_at_barrier.append(device) - if logger.isEnabledFor(logging.DEBUG): - waiting_for_devices = [d for d in _callback_handler_data.devices - if d not in devices_at_barrier] - logger.debug( - "barrier_wait[%s]: still waiting for %s devices at barrier (%s)", - logging_name, len(waiting_for_devices), waiting_for_devices - ) - cv.notify() - - for d_idx, d in enumerate(_callback_handler_data.devices): - logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d) - x_on_dev = api.device_put(d_idx, device=d) - api.jit(lambda x: _deprecated_id_tap(barrier_tap_received, x), device=d)(x_on_dev) - - logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name) - - with lock: - cv.wait_for(lambda: len(devices_at_barrier) == len(_callback_handler_data.devices)) - - logger.debug("barrier_wait[%s]: done", logging_name) - - if _callback_handler_data.last_callback_exception is not None: - last_exception, formatted_last_exception = _callback_handler_data.last_callback_exception - _callback_handler_data.last_callback_exception = None - raise CallbackException( - "There were exceptions during callback processing. " - f"Last one was: {formatted_last_exception}") from last_exception - - -def _deprecated_stop_outfeed_receiver(): - """Stops the outfeed receiver runtime. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. + warnings.warn("""The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the - `new JAX external callbacks `_ - - This waits for all outfeeds from computations already running on all devices, - and then stops the outfeed receiver runtime. The runtime will be restarted - next time you use a tap function. - - It should not be necessary to use this function, unless you want to start - using lax.outfeed directly after having used host callbacks. - """ - _callback_handler_data.stop() - -_deprecation_msg = ( - "The host_callback APIs are deprecated as of March 20, 2024. The functionality " - "is subsumed by the new JAX external callbacks. " - "See https://github.com/jax-ml/jax/issues/20385.") - -_deprecations = { - # Added March 20, 2024 - "id_tap": (_deprecation_msg, _deprecated_id_tap), - "id_print": (_deprecation_msg, _deprecated_id_print), - "call": (_deprecation_msg, _deprecated_call), - "barrier_wait": (_deprecation_msg, _deprecated_barrier_wait), - "stop_outfeed_receiver": (_deprecation_msg, _deprecated_stop_outfeed_receiver), -} + new JAX external callbacks (https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + See https://github.com/jax-ml/jax/issues/20385 + """, DeprecationWarning, stacklevel=2) + if callback_flavor is not None: + raise NotImplementedError( + "host_callback.call is only supported with the IO_CALLBACK flavor.") + if call_with_device: + raise NotImplementedError( + "host_callback.call is only supported with the call_with_device=False.") + callback_device = jax.local_devices()[device_index] + sharding = jax.sharding.SingleDeviceSharding(callback_device) + return io_callback(callback_func, result_shape, arg, + sharding=sharding, + ordered=True) import typing if typing.TYPE_CHECKING: - id_tap = _deprecated_id_tap - id_print = _deprecated_id_print - call = _deprecated_call - barrier_wait = _deprecated_barrier_wait - stop_outfeed_receiver = _deprecated_stop_outfeed_receiver -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr + def id_tap(tap_func, + arg, + *, + result=None, + tap_with_device=False, + device_index=0, + callback_flavor=None, + **kwargs): + raise NotImplementedError( + "host_callback.id_tap is no longer supported. " + "See https://github.com/jax-ml/jax/issues/20385" + ) + del typing diff --git a/tests/BUILD b/tests/BUILD index 7ab6cc136e97..188b5ae814d7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1145,29 +1145,13 @@ jax_multiplatform_test( jax_multiplatform_test( name = "host_callback_test", srcs = ["host_callback_test.py"], - args = ["--jax_host_callback_outfeed=false"], main = "host_callback_test.py", - shard_count = { - "gpu": 5, - }, - tags = ["noasan"], # Times out deps = [ "//jax:experimental", - "//jax:experimental_host_callback", "//jax:ode", ], ) -jax_multiplatform_test( - name = "host_callback_to_tf_test", - srcs = ["host_callback_to_tf_test.py"], - tags = ["noasan"], # Linking TF causes a linker OOM. - deps = [ - "//jax:experimental_host_callback", - "//jax:ode", - ] + py_deps("tensorflow_core"), -) - jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index b23b4c4e7a41..42c4496643bf 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -14,1438 +14,20 @@ from __future__ import annotations -import contextlib -from collections.abc import Callable -from functools import partial -import itertools -import logging -import os -import re -import time -import unittest from unittest import SkipTest from absl.testing import absltest import jax -from jax import ad_checkpoint -from jax import dtypes -from jax import lax -from jax import numpy as jnp from jax.experimental import host_callback as hcb -from jax._src import core from jax._src import xla_bridge from jax._src import test_util as jtu -from jax._src.lib import xla_client - -from jax.experimental.host_callback import _deprecated_id_print as hcb_id_print - -xops = xla_client.ops import numpy as np jax.config.parse_flags_with_absl() -class _TestingOutputStream: - """Use as `output_stream` for tests.""" - - def __init__(self): - self._output = [] - self._test_method_name = None - - def write(self, what: str) -> None: - logging.info(f"output_stream[{self._test_method_name}]: {what}") - self._output.append(what) - - @property - def output(self): - return "".join(self._output) - - @property - def output_sorted_by_device(self): - # Assume that the output is a sequence of strings including metadata - # and data, with metadata containing `device: xxx` - by_device = [] # each element is a pair (device, str_list) - for s in self._output: - m = re.match(r".*device: (\S+)", s) - if m: - by_device.append((m.group(1), [])) - assert by_device, f"output does not include 'device:': {self._output}" - by_device[-1][1].append(s) - - sorted_by_device = sorted(by_device, key=lambda x: x[0]) - return "\n".join(itertools.chain(*[s[1] for s in sorted_by_device])) - - def __str__(self): - return "TestingOutputStream" - - def reset(self): - self._output = [] - - -testing_stream = _TestingOutputStream() - - -def fun1(a): - """Function used for several `id_tap` tests.""" - y = hcb_id_print(a * 2., what="a * 2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - y = hcb_id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return y ** 2 # Some computation to make the gradient interesting - - -def fun1_equiv(a): # Numerical equivalent of fun1 - return (a * 2.) ** 2 - - -def maybe_print(do_print: bool, - arg, - what: str, - tap_with_device: bool | None = False, - device_index: int = 0): - """Conditionally print on testing_string""" - if do_print: - return hcb_id_print( - arg, - what=what, - output_stream=testing_stream, - tap_with_device=tap_with_device, - device_index=device_index) - else: - return arg - - -def local_devices(): - # Tests require using not more than 2 devices. - return jax.local_devices()[:2] - - -ignore_jit_of_pmap_warning = partial( - jtu.ignore_warning, message=".*jit-of-pmap.*") - - -def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, - expected: str, what: str): - """A variant that preprocesses the string to eliminate non-determinism in - floating point values, and several uninteresting id_tap primitive params. - """ - - # Sometimes we get floating points in the output; we round them - def repl_floats(match_group): - matched = match_group.group(0) - if matched == ".": return matched - x = np.around(float(matched), decimals=2) - return f"{x:.2f}" - - what = re.sub(r"\-?\d+\.[\-\def]*", repl_floats, what) - what = re.sub(r"output_stream=[^\]\n,]*,?", "", what) - what = re.sub(r"threshold=[^\]\n,]*,?", "", what) - what = re.sub(r"bwd=[^\]\n]*", "", what) - what = re.sub(r"out_trees=[^\]\n]*", "", what) - what = re.sub(r"fwd_jaxpr_thunk=[^\]\n]*", "", what) - what = re.sub(r"jvp_jaxpr_thunk=[^\]\n]*", "", what) - # Empty lines - what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE) - - def repl_func(match_group): - matched = match_group.group(3) - if "function _print_consumer" in matched: - return match_group.group(1) + "=_print" - else: - return match_group.group(1) + "=..." - - what = re.sub(r"((tap_func_)|(callback))=([^\]\n,]*),?", repl_func, what) - tst.assertMultiLineStrippedEqual(expected, what) - - -def helper_set_hlo_dump(): - flags_str = os.getenv("XLA_FLAGS", "") - import shutil - dump_dir = "/tmp/xla_dump" - os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to={dump_dir}" - if os.path.isdir(dump_dir): - logging.warning("Deleting old XLA dump directory %s", dump_dir) - shutil.rmtree(dump_dir) - logging.warning("Setting XLA dump directory %s", dump_dir) - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - - -def helper_print_optimized_hlo(fun, *args): - backend = xla_bridge.get_backend(platform=jtu.device_under_test()) - c = jax.jit(fun, backend=backend.platform).lower(*args) - logging.info(re.sub(r", metadata.*", "", c.compile().as_text())) - - -def helper_log_ir(name, - f_jax, - *args, - num_partitions=None, - strip_metadata=False): - logging.info(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}") - jax_comp = f_jax.lower(*args) - logging.info(f"HLO[{name}]: {jax_comp.compiler_ir(dialect='hlo').as_hlo_text()}") - jax_optimized_hlo = jax_comp.compile().as_text() - if strip_metadata: - jax_optimized_hlo = re.sub(r", metadata.*", "", jax_optimized_hlo) - logging.info(f"Optimized HLO[{name}]: {jax_optimized_hlo}") - - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - - -def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase, - expected_2CPUs: str): - """Check that the multi-device output is equal to the expected. - - The tests run with 2 devices if available, otherwise 1 device. - We adjust the expected output here for 1 device. - - Args: - expected_2CPUs: the expected output for 2 CPUs. If there is only - one device, this is trimmed to the first device. If the current - device_under_test is not a CPU, then we change the names - """ - expected = expected_2CPUs - if len(local_devices()) == 1: - start_device_1 = expected.find('device: cpu:1') - if start_device_1 >= 0: - expected = expected[0:start_device_1] - - def replace_device_name(m) -> str: - return str(local_devices()[int(m.group(1))]) - - expected = re.sub(r'cpu:(\d+)', replace_device_name, expected) - what = testing_stream.output_sorted_by_device - return assertMultiLineStrippedEqual(tst, expected, what) - - -class HostCallbackImportsTest(jtu.JaxTestCase): - @jtu.ignore_warning( - category=DeprecationWarning, - message="The host_callback APIs are deprecated") - def test_deprecated_imports(self): - if hasattr(hcb, "id_print"): - id_print = hcb.id_print - self.assertIs(id_print, hcb_id_print) - -class HostCallbackTapTest(jtu.JaxTestCase): - - def setUp(self): - # skipping here skips teardown, so do this before super().setUp(). - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="backend and device argument")) - testing_stream.reset() - testing_stream._test_method_name = self._testMethodName - self.old_flags = os.getenv("XLA_FLAGS", "") - - def tearDown(self) -> None: - if os.getenv("XLA_FLAGS") != self.old_flags: - os.environ["XLA_FLAGS"] = self.old_flags - xla_bridge.get_backend.cache_clear() - jax.effects_barrier() - super().tearDown() - - def test_tap_eval(self): - self.assertAllClose((5. * 2.) ** 2, fun1(5.)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: a * 2 - 10.00 - what: y * 3 - 30.00""", testing_stream.output) - - def test_tap_with_tuple_results(self): - def func2(x): - x1, y1 = hcb_id_print((x * 2., x * 3.), output_stream=testing_stream) - return x1 + y1 - - self.assertEqual(3. * (2. + 3.), func2(3.)) - jax.effects_barrier() - - assertMultiLineStrippedEqual(self, """ - ( 6.00 9.00 )""", testing_stream.output) - - def test_tap_with_dict_results(self): - def func2(x): - res = hcb_id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream) - return res["a"] + res["b"] - - self.assertEqual(3. * (2. + 3.), func2(3.)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - { a=6.00 b=9.00 }""", testing_stream.output) - - def test_tap_with_result(self): - def func2(x): - x1 = hcb_id_print((x * 2., x * 3.), result=x * 4., - output_stream=testing_stream) - return x1 - - self.assertEqual(3. * 4., func2(3.)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - ( 6.00 9.00 )""", testing_stream.output) - - def test_tap_with_result_no_arg(self): - def tap_func(arg, transforms): - testing_stream.write(f"called tap_func with {arg}") - - def func2(x): - x1 = hcb.id_tap(tap_func, None, result=x) - return x1 - - self.assertEqual(3., func2(3.)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, "called tap_func with None", - testing_stream.output) - - def test_tap_result_unused(self): - def tap_func(arg, transforms): - testing_stream.write(f"called tap_func with {arg}") - def func2(x): - hcb.id_tap(tap_func, None) - return x - - self.assertEqual(3., func2(3.)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, "called tap_func with None", - testing_stream.output) - - def test_tap_empty(self): - """Tap empty arrays.""" - hcb_id_print((), output_stream=testing_stream) - hcb_id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - ( ) - what: second - ( 1.00 [] )""", testing_stream.output) - - def test_tap_jit_simple(self): - jit_fun1 = jax.jit(lambda x: 3. * hcb_id_print( - 2. * x, what="here", output_stream=testing_stream)) - self.assertAllClose(6. * 5., jit_fun1(5.)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: here - 10.00""", testing_stream.output) - - def test_tap_jit_no_invars(self): - def func(): # jitted function does not take arguments - return hcb_id_print(42, output_stream=testing_stream) - - self.assertAllClose(42, jax.jit(func)()) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - 42""", testing_stream.output) - - def test_tap_jit_multiple_invars(self): - def func(x1, x2): - return hcb_id_print(x1 + x2, output_stream=testing_stream) - - self.assertAllClose(42, jax.jit(func)(40, 2)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - 42""", testing_stream.output) - - def test_tap_jit_constant(self): - def func(x): - return hcb_id_print(42, result=x, output_stream=testing_stream) - - self.assertAllClose(5, jax.jit(func)(5)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - 42""", testing_stream.output) - - def test_tap_jit_sequence1(self): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - return hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - logging.info("%s: %s", self._testMethodName, - jax.make_jaxpr(func)(1)) - logging.info( - "%s: %s", - self._testMethodName, - jax.jit(func) - .trace(1) - .lower(lowering_platforms=(jtu.device_under_test(),)).as_text("hlo")) - self.assertEqual(2, jax.jit(func)(1)) - jax.effects_barrier() - - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2""", testing_stream.output) - - def test_tap_jit2(self): - """A sequence of JIT.""" - - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - return x2 - - self.assertEqual(2, jax.jit(func)(1)) - self.assertEqual(11, jax.jit(func)(10)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: 1 - 10 - where: 2 - 11""", testing_stream.output) - - def test_tap_jit_result_unused(self): - """We can id_print even if we don't use the result.""" - - def func(x): - hcb_id_print(x, where="1", output_stream=testing_stream) - hcb_id_print(x + 1, where="2", output_stream=testing_stream) - return x + 1 - - self.assertEqual(2, jax.jit(func)(1)) - self.assertEqual(11, jax.jit(func)(10)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: 1 - 10 - where: 2 - 11""", testing_stream.output) - - def test_tap_jit_nested(self): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - - def func_nested(x): - x2 = hcb_id_print(x + 1, where="nested", output_stream=testing_stream) - return x2 - - x3 = jax.jit(func_nested)(x1) - return hcb_id_print(x3 + 1, where="3", output_stream=testing_stream) - - self.assertEqual(3, jax.jit(func)(1)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: nested - 2 - where: 3 - 3""", testing_stream.output) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_pytree(self, with_jit=False): - def func(x, what=""): - """Returns some pytrees depending on x""" - if what == "pair_1_x": - return (1, x) - elif what == "pair_x_2x": - return (x, 2 * x) - elif what == "dict": - return dict(a=2 * x, b=3 * x) - else: - assert False - - tap_count = 0 - - def tap_func(a, _, *, what=""): - nonlocal tap_count - tap_count += 1 - self.assertEqual(func(5, what), a) - - transform = jax.jit if with_jit else lambda f: f - for what in ("pair_1_x", "pair_x_2x", "dict"): - transformed = transform( - lambda x: hcb.id_tap( - partial(tap_func, what=what), - func(x, what), - result=func(x * 2, what)) - )(5) - self.assertEqual(func(10, what), transformed) - jax.effects_barrier() # Wait for receivers to be done - self.assertEqual(3, tap_count) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_cond(self, with_jit=False): - """A conditional""" - - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - x4 = lax.cond(x % 2 == 0, - lambda x: hcb_id_print(x, where="cond_t", - output_stream=testing_stream), - lambda x: hcb_id_print(-1, where="cond_f", result=x, - output_stream=testing_stream), - x2 + 1) - x5 = hcb_id_print(x4 + 1, where="end", output_stream=testing_stream) - return x5 - - transform = jax.jit if with_jit else lambda f: f - self.assertEqual(4, transform(func)(1)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: cond_f - -1 - where: end - 4""", testing_stream.output) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_while_cond(self, with_jit=False): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - def body(x): - x3 = hcb_id_print(x, where="w_b_1", output_stream=testing_stream) - x4 = lax.cond(x % 2 == 0, - lambda x: hcb_id_print(x, where="w_b_t", - output_stream=testing_stream), - lambda x: hcb_id_print(-1, where="w_b_f", - result=x, output_stream=testing_stream), - x3 + 1) - return hcb_id_print(x4, where="w_b_2", output_stream=testing_stream) - - x10 = lax.while_loop(lambda x: x <= 3, body, x2) - res = hcb_id_print(x10, where="end", output_stream=testing_stream) - return res - - transform = jax.jit if with_jit else lambda f: f - self.assertEqual(4, transform(func)(1)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: w_b_1 - 2 - where: w_b_t - 3 - where: w_b_2 - 3 - where: w_b_1 - 3 - where: w_b_f - -1 - where: w_b_2 - 4 - where: end - 4""", testing_stream.output) - - def test_tap_jit_while_pred_tap(self): - """While with printing in the conditional.""" - - def func(x): - x1 = hcb_id_print(x, where="1") - x10 = lax.while_loop(lambda x: hcb_id_print(x < 3, - where="w_p", - output_stream=testing_stream), - lambda x: hcb_id_print(x + 1, where="w_b", - output_stream=testing_stream), - x1) - res = hcb_id_print(x10, where="3", output_stream=testing_stream) - return res - - self.assertEqual(3, jax.jit(func)(1)) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, - """ - where: w_p - True - where: w_b - 2 - where: w_p - True - where: w_b - 3 - where: w_p - False - where: 3 - 3""", testing_stream.output) - - @jtu.sample_product(with_jit=[True, False]) - def test_tap_scan_cond(self, with_jit=True): - def func(x): - x1 = hcb_id_print(x, where="1", output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, where="2", output_stream=testing_stream) - - def body(c, x): - x3 = hcb_id_print(x, where="s_1", output_stream=testing_stream) - x4 = lax.cond(x % 2 == 0, - lambda x: hcb_id_print(x, where="s_t", output_stream=testing_stream), - lambda x: hcb_id_print(-1, where="s_f", result=x, output_stream=testing_stream), - x3 + 1) - return (c, hcb_id_print(x4, where="s_2", output_stream=testing_stream)) - - _, x10 = lax.scan(body, x2, jnp.arange(3)) - res = hcb_id_print(x10, where="10", output_stream=testing_stream) - return res - - if with_jit: - func = jax.jit(func) - res = func(1) - self.assertAllClose(jnp.arange(1, 4), res) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - where: 1 - 1 - where: 2 - 2 - where: s_1 - 0 - where: s_t - 1 - where: s_2 - 1 - where: s_1 - 1 - where: s_f - -1 - where: s_2 - 2 - where: s_1 - 2 - where: s_t - 3 - where: s_2 - 3 - where: 10 - [1 2 3]""", testing_stream.output) - testing_stream.reset() - - @jtu.sample_product( - nr_args=[1, 2], - shape=[(), (2,), (2, 3), (2, 3, 4)], - dtype=jtu.dtypes.all, - ) - def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)): - if dtype in (jnp.complex64, jnp.complex128, jnp.bool_): - raise SkipTest(f"host_callback not implemented for {dtype}.") - if dtype == np.bool_: - args = [self.rng().choice(a=[True, False], size=shape)] - else: - args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)] - if nr_args > 1: - args = args * nr_args - jit_fun1 = jax.jit(lambda xs: hcb_id_print( - xs, - a_new_test="************", - testcase_name=f"{shape=}_{dtype=}_{nr_args=}")) - - res = jit_fun1(args) - self.assertAllClose(args, res, check_dtypes=True) - - def test_tap_jit_large(self): - arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1)) - jax.jit(hcb_id_print)(arg) - - def test_tap_jit_several_together(self): - arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5)) - jax.jit(lambda x, y: hcb_id_print((x, y, x * 2)))(arg, jnp.ones(100, dtype=jnp.int32)) - - def test_tap_jit_interleaving(self): - # Several jit's without data dependencies; they may interfere - count = 0 # Count tap invocations - nr_arrays = 5 - - def tap_func(arg, _): - nonlocal count - assert len(arg) == nr_arrays - count += 1 - - # This is the function that we'll run multiple times - def func(x, count): - for i in range(count): - x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)])[-1] - return x - - x = jnp.array(1, dtype=np.int32) - res = 0 - for _ in range(10): - # No dependencies between the jit invocations - res += jax.jit(lambda x: func(x, 10))(x) - jax.effects_barrier() - self.assertEqual(100, count) - - def test_tap_while(self): - """Executing while, even without JIT uses compiled code""" - y = jnp.ones(5) # captured const - - def func(x): - return lax.while_loop( - lambda c: c[1] < 5, - lambda c: (y, hcb_id_print(c[1], output_stream=testing_stream) + 1), - (x, 1)) - - func(y) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - 1 - 2 - 3 - 4""", testing_stream.output) - - def test_tap_jvp(self): - jvp_fun1 = lambda x, xt: jax.jvp(fun1, (x,), (xt,)) - res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1)) - self.assertAllClose(100., res_primals, check_dtypes=False) - self.assertAllClose(4., res_tangents, check_dtypes=False) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: a * 2 - 10.00 - what: y * 3 - 30.00""", testing_stream.output) - - def test_tap_grad_primal_unused(self): - # The output of id_print is not needed for backwards pass - def func(x): - return 2. * hcb_id_print(x * 3., what="x * 3", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - grad_func = jax.grad(func) - arg = jnp.float32(5.) - jaxpr = str(jax.make_jaxpr(grad_func)(arg)) - # making the Jaxpr does not print anything - jax.effects_barrier() - - if hcb._HOST_CALLBACK_LEGACY.value: - treedef = jax.tree.structure(arg) - assertMultiLineStrippedEqual( - self, f""" - {{ lambda ; a:f32[]. let - b:f32[] = mul a 3.00 - c:f32[] = outside_call[ - arg_treedef={treedef} - callback=... - device_index=0 - identity=True - ] b - _:f32[] = mul 2.00 c - d:f32[] = mul 2.00 1.00 - e:f32[] = mul d 3.00 - in (e,) }}""", jaxpr) - assertMultiLineStrippedEqual(self, "", testing_stream.output) - testing_stream.reset() - - res_grad = grad_func(arg) - jax.effects_barrier() - - self.assertAllClose(6., res_grad, check_dtypes=False) - assertMultiLineStrippedEqual(self, """ - what: x * 3 - 15.00""", testing_stream.output) - - def test_tap_grad_simple(self): - def func(x): - y = hcb_id_print(x * 2., what="x * 2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * hcb_id_print(y * 3., what="y * 3", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - grad_func = jax.grad(func) - - res_grad = grad_func(jnp.float32(5.)) - self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: x * 2 - 10.00 - what: y * 3 - 30.00""", testing_stream.output) - - def test_tap_grad_grad(self): - def func(x): - y = hcb_id_print(x * 2., what="x * 2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * (y * 3.) - - grad_func = jax.grad(jax.grad(func)) - # making the Jaxpr does not print anything - _ = jax.make_jaxpr(grad_func)(5.) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, "", testing_stream.output) - - res_grad = grad_func(jnp.float32(5.)) - - self.assertAllClose(12., res_grad, check_dtypes=False) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: x * 2 - 10.00""", testing_stream.output) - - def test_tap_grad_pytree(self): - def func(x): - x4, x5 = hcb_id_print((x * 2., x * 3.), what="pair", - result=(x * 4., x * 5.), - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x4 + 2. * x5 - - x = jnp.float32(5.) - grad_func = jax.grad(func) - print(jax.make_jaxpr(grad_func)(x)) - res_grad = grad_func(x) - self.assertAllClose(14., res_grad, check_dtypes=False) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: pair - ( 10.00 15.00 )""", testing_stream.output) - - def test_tap_jvp_float0(self): - def f(x, yint): - x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint), - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * yint - - res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0))) - self.assertAllClose((6., 0.6), res) - - def test_tap_grad_float0(self): - - def func(x, yint): - x, yint = hcb_id_print((x, yint), what="pair", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x * yint.astype(x.dtype) - - grad_func = jax.grad(func) - - res_grad = grad_func(jnp.float32(5.), jnp.int32(2)) - self.assertAllClose(2., res_grad, check_dtypes=False) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - what: pair - ( 5.00 2 )""", testing_stream.output) - - def test_tap_grad_float0_result(self): - # https://github.com/jax-ml/jax/issues/7340 - # x is a Tuple[f32[2], s32[3]] - x = (np.array([.7, .8], dtype=np.float32), - np.array([11, 12, 13], dtype=np.int32)) - def f_jax(x): - x = hcb_id_print(x, result=x, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important - return (3. * x[0], x[1]) - - def f_jax_vjp(x): - res, pullback = jax.vjp(f_jax, x) - g, = pullback((np.ones(x[0].shape, dtype=x[0].dtype), - np.zeros(x[1].shape, dtype=dtypes.float0))) - return g - - g = f_jax_vjp(x) - self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0]) - self.assertEqual(dtypes.float0, g[1].dtype) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] )""", testing_stream.output) - - def test_tap_higher_order_grad_float0_result(self): - # https://github.com/jax-ml/jax/issues/7340 - # x is a Tuple[f32[2], s32[3]] - x = (np.array([.7, .8], dtype=np.float32), - np.array([11, 12, 13], dtype=np.int32)) - def f_jax(x): - x = hcb_id_print(x, result=x, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important - return (jnp.sin(x[0]), x[1]) - - def wrap_vjp(f, args, res_f_of_args): - # Given a function "f" and "args" return the f_vjp and args_vjp - def make_ct(res): - res_dtype = np.result_type(res) - if res_dtype == dtypes.float0: - return res - ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype) - return np.ones(np.shape(res), dtype=ct_dtype) - cts = jax.tree.map(make_ct, res_f_of_args) - def f_vjp(args, cts): - res, pullback = jax.vjp(f, *args) - return pullback(cts) - return (f_vjp, (args, cts)) - - res = f_jax(x) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] )""", testing_stream.output) - testing_stream.reset() - - # 1st order - f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res) - res_vjp1 = f_jax_vjp1(*args_vjp1) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - ( [0.70 0.80] [11 12 13] )""", testing_stream.output) - testing_stream.reset() - - # 2nd order - f_jax_vjp2, args_vjp2 = wrap_vjp(f_jax_vjp1, args_vjp1, res_vjp1) - res_vjp2 = f_jax_vjp2(*args_vjp2) - - # 3rd order - f_jax_vjp3, args_vjp3 = wrap_vjp(f_jax_vjp2, args_vjp2, res_vjp2) - _ = f_jax_vjp3(*args_vjp3) - - def test_tap_vmap(self): - vmap_fun1 = jax.vmap(fun1) - vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) - vmap_fun1(vargs) - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 - [ 8.00 10.00] - transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 - [24.00 30.00]""", testing_stream.output) - else: - assertMultiLineStrippedEqual(self, """ - what: a * 2 - 8.00 - what: a * 2 - 10.00 - what: y * 3 - 24.00 - what: y * 3 - 30.00 - """, testing_stream.output) - - def test_tap_vmap_not_batched(self): - x = 3. - - def func(y): - # x is not mapped, y is mapped - _, y = hcb_id_print((x, y), output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return x + y - - vmap_func = jax.vmap(func) - vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) - _ = vmap_func(vargs) - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (None, 0)})] - ( 3.00 [4.00 5.00] )""", testing_stream.output) - else: - assertMultiLineStrippedEqual(self, """ - ( 3.00 4.00 ) - ( 3.00 5.00 ) - """, testing_stream.output) - - def test_tap_vmap_vmap(self): - # A 2D tensor with x[i, j] = i + j using 2 vmap - def sum(x, y): - return hcb_id_print(x + y, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - def sum_rows(xv, y): - return jax.vmap(sum, in_axes=(0, None))(xv, y) - - def sum_all(xv, yv): - return jax.vmap(sum_rows, in_axes=(None, 0))(xv, yv) - - xv = jnp.arange(5, dtype=np.int32) - yv = jnp.arange(3, dtype=np.int32) - # assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv))) - _ = sum_all(xv, yv) - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] - [[0 1 2 3 4] - [1 2 3 4 5] - [2 3 4 5 6]]""", testing_stream.output) - else: - assertMultiLineStrippedEqual(self, """ - 0 - 1 - 2 - 1 - 2 - 3 - 2 - 3 - 4 - 3 - 4 - 5 - 4 - 5 - 6 - """, testing_stream.output) - - def test_tap_vmap_while(self): - """Vmap of while.""" - - def func(x): - # like max(x, 2) - x1 = hcb_id_print(x, where="before:x", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - x2 = lax.while_loop( - lambda x: x < 2, lambda x: hcb_id_print( - x + 1, where="body:x+1", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG), x1) - res = hcb_id_print(x2, where="after:x", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return res - - inputs = np.arange(5, dtype=np.int32) - self.assertAllClose( - np.array([2, 2, 2, 3, 4]), - jax.jit(jax.vmap(func))(inputs), - check_dtypes=False) - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual( - self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: before:x - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: after:x - [2 2 2 3 4]""", testing_stream.output) - else: - pass # order of vmaps is not guaranteed - - def test_tap_vmap_while_tap_cond(self): - """Vmap of while, with a tap in the conditional.""" - - def func(x): - # like max(x, 2) - x1 = hcb_id_print(x, where="1", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - x2 = lax.while_loop(lambda x: hcb_id_print(x < 2, where="w_c", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG), - lambda x: hcb_id_print(x + 1, where="w_b", - output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG), - x1) - res = hcb_id_print(x2, where="3", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return res - - inputs = np.arange(5, dtype=np.int32) - res = jax.jit(jax.vmap(func))(inputs) - jax.effects_barrier() - self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) - if hcb._HOST_CALLBACK_LEGACY.value: - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: 1 - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True True False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [False False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: 3 - [2 2 2 3 4]""", testing_stream.output) - else: - pass # order of vmap is not guaranteed - - def test_tap_transforms_doc(self): - # Examples from the documentation - def power3(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return y * x - - print(f"impl = {power3(3.)}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - @jax.custom_jvp - def print_tangents(arg): - return None - - @print_tangents.defjvp - def print_tangents_jvp(primals, tangents): - arg_dot, = tangents - hcb_id_print(arg_dot, what="tangents", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return primals, tangents - - def power3_with_tangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - print_tangents((x, y)) - return y * x - - print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. ) - what: tangents - ( 0.1 0.6 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - testing_stream.reset() - - print(f"grad = {jax.grad(power3)(3.)}") - jax.effects_barrier() - # Only the primals by default - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - @jax.custom_vjp - def print_cotangents(arg): - # Must return the argument for which we want the cotangent. - return arg - - # f_fwd: a -> (b, residual) - def print_cotangents_fwd(arg): - return print_cotangents(arg), None - # f_bwd: (residual, CT b) -> [CT a] - def print_cotangents_bwd(residual, ct_b): - hcb_id_print(ct_b, what="cotangents", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return ct_b, - - print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) - - def power3_with_cotangents(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb_id_print((x, y), what="x,x^2", output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - # Must use the output of print_cotangents - (x1, y1) = print_cotangents((x, y)) - return y1 * x1 - - print(f"grad = {jax.grad(power3_with_cotangents)(3.)}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. ) - what: cotangents - ( 9. 3. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 ) - what: cotangents - ( 9.0 3.0 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - # TODO: grad of grad - - print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" - else: - expected = """ - what: x,x^2 - ( 2.0 4.0 ) - what: x,x^2 - ( 3.0 9.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" - else: - expected = """ - what: x,x^2 - ( 2.0 4.0 ) - what: x,x^2 - ( 3.0 9.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] ) - transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents - ( [4. 9.] [2. 3.] )""" - else: - expected = """ - what: x,x^2 - ( 2.0 4.0 ) - what: x,x^2 - ( 3.0 9.0 ) - what: cotangents - ( 4.0 2.0 ) - what: cotangents - ( 9.0 3.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}") - jax.effects_barrier() - if hcb._HOST_CALLBACK_LEGACY.value: - expected = """ - what: x,x^2 - ( 3. 9. ) - what: x,x^2 - ( 27. 729. ) - what: x,x^2 - ( 3. 9. )""" - else: - expected = """ - what: x,x^2 - ( 3.0 9.0 ) - what: x,x^2 - ( 27.0 729.0 ) - what: x,x^2 - ( 3.0 9.0 ) - """ - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - - @unittest.skip("cond of pmap does not work in JAX. Issue #5178.") - def test_tap_cond_pmap(self): - # A matrix M[ij] = i * 10 + j - nr_devices = len(local_devices()) - shape = (nr_devices, 3) - matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, - dtype=np.float32) - - def fun1(x, do_print=False): - return maybe_print(do_print, x * 2., "x * 2") - - def fun2(cond, xv, do_print=False): - return lax.cond(cond, jax.pmap(partial(fun1, do_print=do_print)), - lambda xv: xv, xv) - - res = fun2(True, matrix) - self.assertAllClose(fun2(True, matrix, do_print=False), res, check_dtypes=False) - jax.effects_barrier() - assertMultiLineStrippedEqual(self, """ - TBD""", testing_stream.output) - - def test_tap_callback_delay(self): - hcb.callback_extra = lambda dev: time.sleep(1) - - def func(x): - for i in range(5): - x = hcb_id_print(x * i, what="x times i") - return x - - jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) - - def test_tap_callback_delay_barrier(self): - hcb.callback_extra = lambda dev: time.sleep(2) - - def func(x): - for i in range(1, 4): - x = hcb_id_print(x * i, what=f"x times {i}", output_stream=testing_stream) - return x - - jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) - # Wait for the results - jax.effects_barrier() - expected = """ - what: x times 1 - [[0. 1. 2.] - [3. 4. 5.]] - what: x times 2 - [[ 0. 2. 4.] - [ 6. 8. 10.]] - what: x times 3 - [[ 0. 6. 12.] - [18. 24. 30.]]""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - testing_stream.reset() - # Call again - jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) - jax.effects_barrier() - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - def test_tap_error_bad_consumer_id(self): - """Try to use reserved consumer ID 0. - - Check that we get the proper error from the runtime.""" - if not hcb._use_outfeed(jtu.device_under_test()): - raise SkipTest("test works only for outfeed") - comp = xla_client.XlaBuilder(self._testMethodName) - token = hcb.xops.CreateToken(comp) - hcb._initialize_outfeed_receiver() # Needed if this is the sole test - with self.assertRaisesRegex(RuntimeError, - "Consumer ID cannot be a reserved value: 0"): - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 0, - [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0) - - def test_tap_error_different_shapes(self): - """Try to register different shapes for the same consumer ID.""" - if not hcb._use_outfeed(jtu.device_under_test()): - raise SkipTest("test works only for outfeed") - comp = xla_client.XlaBuilder(self._testMethodName) - token = hcb.xops.CreateToken(comp) - hcb._initialize_outfeed_receiver() # Needed if this is the sole test - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 123, - [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0) - with self.assertRaisesRegex( - RuntimeError, ".*does not match previous shape .*\n?element_type.*"): - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 123, - [xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))], 0) - with self.assertRaisesRegex( - RuntimeError, ".*does not match previous shape .*\n?element_type.*"): - hcb._callback_handler_data.receiver.add_outfeed( - comp, token, 123, - [xops.Constant(comp, np.zeros((2,), dtype=np.float32))], 0) - - def test_tap_id_tap_removed_kwargs(self): - def func(x, transforms, y): - pass - - with self.assertRaisesRegex(TypeError, r"Support for \*\*kwargs in ``id_tap``"): - hcb.id_tap(func, 1, y=2) - - def test_tap_id_tap_random_key(self): - # See https://github.com/jax-ml/jax/issues/13949 - with jax.enable_custom_prng(): - @jax.jit - def f(x): - def tap(tap_x, _): pass - return hcb.id_tap(tap, x, result=x) - f(jax.random.PRNGKey(123)) - - def test_tap_odeint(self): - # TODO: find a smaller repro for bug #4015 - # Seems to be xla_call(scan(xla_call)), all under grad. - from jax.experimental.ode import odeint - - def f(x, t, k): - x = hcb_id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG) - return -k * x - - def loss(k=1.0): - t = jnp.linspace(0, 0.001, num=2) - xs = odeint(f, 1.0, t, k) - return xs[-1] - - jax.grad(loss)(1.0) # should not fail - - def test_tap_remat_0(self): - def f(i, k): - x = hcb_id_print(k + i, output_stream=testing_stream, - callback_flavor=hcb.CallbackFlavor.DEBUG) - return k * x - - def loss(k): - return lax.fori_loop(0, 2, jax.remat(f), k) - - print(loss(3)) - jax.effects_barrier() - expected = """ - 3 - 10""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - def test_tap_named_call(self): - def tap_scalar(init, do_print=False): - @partial(jax.named_call, name="step") - def step(acc, step_nr): - acc = acc + step_nr - maybe_print(do_print, step_nr, what="step_nr") - return acc, None - - return lax.scan(step, init, np.arange(2)) - - self.assertAllClose(tap_scalar(3, do_print=False), tap_scalar(3, do_print=True)) - jax.effects_barrier() - expected = """ - what: step_nr - 0 - what: step_nr - 1""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - class HostCallbackCallTest(jtu.JaxTestCase): """Tests for hcb.call""" @@ -1461,25 +43,10 @@ def setUp(self): self.enter_context(jtu.ignore_warning( category=DeprecationWarning, message="backend and device argument")) - testing_stream.reset() - testing_stream._test_method_name = self._testMethodName - def tearDown(self) -> None: jax.effects_barrier() super().tearDown() - def call_log_testing_stream(self, func, arg, *, result_shape, name=""): - """Call `func` and log inputs and outputs to the testing stream""" - - def call_log(arg): - def val2str(v): - return np.array2string(np.array(arg)) - testing_stream.write(f"Call {name}({val2str(arg)})\n") - res = func(arg) - testing_stream.write(f" = {val2str(res)}\n") - return res - return hcb.call(call_log, arg, result_shape=result_shape) - def test_call_simple(self): def f_outside(x): @@ -1492,20 +59,6 @@ def fun(x): arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) self.assertAllClose(3 * (1 + 2 * (arg + 1)), fun(arg)) - def test_primitive_compilation(self): - - def f_outside(x): - return 2 * x - - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - with jtu.count_primitive_compiles() as count: - for _ in range(3): - self.assertAllClose(2 * arg, fun(arg)) - r = jax.make_jaxpr(fun)(arg) - self.assertEqual(count[0], 1) @jtu.sample_product( dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_], @@ -1546,346 +99,6 @@ def fun(x): arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) self.assertAllClose(2 * (arg + 1) + 3 * arg, fun(arg)) - def test_call_no_arg(self): - """Call with no arguments.""" - result = np.ones((2,), dtype=np.float32) - def f_outside(in_tuple): - assert len(in_tuple) == 0 - return result - def fun(x): - return x + hcb.call(f_outside, (), - result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype)) - self.assertAllClose(2. + result, fun(2.)) - - def test_call_empty_arg(self): - """Call with empty array.""" - result = np.full((2,), 3., dtype=np.float32) - def f_outside(x0): # x0: f32[2, 0] - return result - x0 = np.ones((2, 0), dtype=np.float32) - def fun(x): - return x + hcb.call(f_outside, x0, - result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype)) - self.assertAllClose(2. + result, fun(2.)) - - def test_call_empty_arg_inside_pytree(self): - """Call taking tuple with an empty array and a non-empty one.""" - x0 = np.ones((2, 0), dtype=np.float32) - x1 = np.full((2,), 3., dtype=np.float32) - result = x1 - def f_outside(in_tuple): # x0: f32[2, 0] x1: f32[2] - return in_tuple[1] - - def fun(x): - res = hcb.call(f_outside, (x0, x1), - result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype)) - return x + res - self.assertAllClose(2. + result, fun(2.)) - - def test_call_empty_result(self): - """Call returning empty array.""" - result_shape = (2, 0) - def f_outside(_): - return np.ones(result_shape, dtype=np.float32) - def fun(x): - return x + hcb.call(f_outside, 1., - result_shape=jax.ShapeDtypeStruct(result_shape, np.float32)) - self.assertAllClose(f_outside(0.), fun(2.)) - - def test_call_empty_result_inside_pytree(self): - """Call returning a tuple with an empty array and a non-empty one.""" - result_shape_0 = (2, 0) - result_shape_2 = (0,) - def f_outside(_): - return (np.ones(result_shape_0, dtype=np.float32), - np.ones((1,), dtype=np.float32), - np.ones(result_shape_2, dtype=np.float32)) - def fun(x): - res = hcb.call(f_outside, 1., - result_shape=(jax.ShapeDtypeStruct(result_shape_0, np.float32), - jax.ShapeDtypeStruct((1,), np.float32), - jax.ShapeDtypeStruct(result_shape_2, np.float32))) - self.assertEqual(result_shape_0, res[0].shape) - self.assertEqual(result_shape_2, res[2].shape) - return x + res[1] - self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.)) - - def test_call_empty_result_all_pytree(self): - """Call returning a tuple of empty arrays.""" - result_shape = (2, 0) - def f_outside(_): - return (np.ones(result_shape, dtype=np.float32), - np.ones(result_shape, dtype=np.float32)) - def fun(x): - res = hcb.call(f_outside, 1., - result_shape=(jax.ShapeDtypeStruct(result_shape, np.float32), - jax.ShapeDtypeStruct(result_shape, np.float32))) - return x + res[0] + res[1] - self.assertAllClose(np.ones(result_shape, dtype=np.float32), - fun(2.)) - - def test_call_no_result(self): - def f_outside(arg): - self.call_log_testing_stream(lambda x: None, arg, - result_shape=None, - name="outside") - return arg - - self.assertAllClose((3., 4.), f_outside((3., 4.))) - jax.effects_barrier() - expected = """ - Call outside([3. 4.]) - = [3. 4.]""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - - def test_call_cond(self): - def f_outside(args): - x, y = args - return x * y.astype(np.float32) - - def loop(x, use_outside=True): - def body(i, acc): - return lax.cond(i % 2 == 1, - lambda _: (hcb.call(f_outside, (acc, i), - result_shape=acc) - if use_outside else f_outside((acc, i))), - lambda _: acc, - None) - - return lax.fori_loop(0, 18, body, x) - - res_inside = loop(np.float32(1.2), use_outside=False) - self.assertAllClose(res_inside, jax.jit(loop)(np.float32(1.2))) - - def test_call_jit_scan_call(self): - def f_outside(x): - return x - - def loop(x, use_outside=True): - def body(carry, i): - if use_outside: - return carry + hcb.call(f_outside, i, - result_shape=i), None - else: - return carry + i, None - - return lax.scan(body, 0, x) - - x = np.arange(5, dtype=np.int32) - - res_outside = jax.jit(partial(loop, use_outside=True))(x) - self.assertAllClose(res_outside, loop(x, use_outside=False)) - - def test_call_doc_example1(self): - """Examples from the documentation: simplest, call a function""" - - def host_eig(x): - return np.linalg.eigvals(x) - - shape = (2, 5, 4, 4) - - m = np.ones(shape, dtype=np.float32) - - def fun(m): - eig_m = hcb.call(host_eig, m, - result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype)) - return eig_m - - expected_res = np.linalg.eigvals(m) - self.assertAllClose(expected_res, fun(m)) - @jtu.skip_on_devices("gpu") - def test_call_doc_example_hlo(self): - """Examples from the documentation: simplest, call a function.""" - - def fun1(m): - return jnp.sin(hcb.call(lambda x: np.cos, - jnp.cos(m), - result_shape=m)) - - m = np.ones((2,), np.float32) - helper_print_optimized_hlo(fun1, m) - - def fun2(m): - x = hcb.call(lambda x: None, 2, result_shape=()) - return x - - m = np.ones((2,), np.float32) - helper_print_optimized_hlo(fun2, m) - - def test_call_vmap(self): - def f_outside(x): return x - - def fun(x): - return hcb.call(f_outside, x, result_shape=x, - callback_flavor=hcb.CallbackFlavor.PURE) - - if hcb._HOST_CALLBACK_LEGACY.value: - with self.assertRaisesRegex(NotImplementedError, - "batching rules are implemented only for id_tap, not for call"): - jax.vmap(fun)(np.ones((2, 3))) - else: - with jtu.ignore_warning(category=DeprecationWarning): - jax.vmap(fun)(np.ones((2, 3))) - - def test_call_error_bad_result_shape(self): - with self.assertRaisesRegex( - ValueError, - "The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"): - hcb.call(lambda x: x, 3., result_shape="string") - - with self.assertRaisesRegex( - ValueError, - "The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"): - hcb.call(lambda x: x, 3., result_shape=lambda x: x) - jax.effects_barrier() - - def helper_check_callback_errors(self, thunk: Callable, - expected_exc_txt: str): - """Calls thunk() and checks for expected exceptions. - """ - if jtu.test_device_matches(["cpu"]): - # On CPU the runtime crashes, and the tests are all aborted - raise SkipTest("TODO: CPU runtime crashes on unexpected infeed") - elif jtu.test_device_matches(["gpu"]): - # On GPU we get a nice error back to Python - with self.assertRaisesRegex( - RuntimeError, - "(.* Mismatch between infeed source buffer shape s8.12345." - "|.*The destination shape does not match the source shape.)"): - thunk() - elif jtu.test_device_matches(["tpu"]): - # On TPU we get no error!!! - raise SkipTest("TODO: TPU runtime does not check infeed, and just computes with garbage") - - # Both on GPU and TPU we also get an error during the barrier_wait at the - # end of the test. Run a barrier_wait now, to consume that error. - with self.assertRaisesRegex( - hcb.CallbackException, - re.compile( - "There were exceptions during callback processing.*Last one was:.*" + - expected_exc_txt, - re.DOTALL)): - jax.effects_barrier() - - -def call_jax_other_device( - jax_outside_fun, arg, *, device, - callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK): - """Calls a JAX function on a specific device with simple support for reverse AD. - - Functions whose name starts with "jax_outside" are called on another device, - by way of hcb.call. - """ - - def run_jax_outside_fun(arg): - return jax.jit(jax_outside_fun)(jax.device_put(arg, device)) - - @jax.custom_vjp - def make_call(arg): - return hcb.call(run_jax_outside_fun, arg, - result_shape=jax.eval_shape(jax_outside_fun, arg), - callback_flavor=callback_flavor) - - # Define the fwd and bwd custom_vjp functions - def make_call_vjp_fwd(arg): - # Return the primal argument as the residual. Use `make_call` for the - # primal computation to enable higher-order AD. - return make_call(arg), arg # Return the primal argument as the residual - - def make_call_vjp_bwd(res, ct_res): - arg = res # residual is the primal argument - - def jax_outside_vjp_fun(arg_and_ct): - arg, ct = arg_and_ct - _, f_vjp = jax.vjp(jax_outside_fun, arg) - ct_in, = f_vjp(ct) - return ct_in - - return (call_jax_other_device(jax_outside_vjp_fun, (arg, ct_res), device=device),) - - make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) - return make_call(arg) - - -class CallJaxTest(jtu.JaxTestCase): - """Tests using `call_jax_other_device`.""" - - def setUp(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - - if not jtu.test_device_matches(["cpu"]): - assert jax.devices("cpu") - self.outside_device = jax.devices("cpu")[0] - else: - if len(jax.devices("cpu")) == 1: - raise SkipTest("Test needs at least two devices. On CPU use XLA_FLAGS=--xla_force_host_platform_device_count=2") - self.outside_device = jax.devices("cpu")[1] - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - - - def test_jax_impl(self): - def f_jax(x): - return jnp.sin(x) - - def f_outside(x): - return call_jax_other_device(f_jax, x, device=self.outside_device) - - self.assertAllClose(f_jax(3.), f_outside(3.)) - self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.)) - - def test_jax_impl_pytree(self): - def f_jax(x): - # x : dict(a=..., b=...) and output is a list of two elements - return [jnp.sin(x["a"]), jnp.sin(x["b"])] - - def f_outside(x): - return call_jax_other_device(f_jax, x, device=self.outside_device) - - x = dict(a=3., b=4.) - res_jax = f_jax(x) - # print(f"outside_jaxpr = {jax.make_jaxpr(f_outside)(x)}") - res_outside = f_outside(x) - self.assertAllClose(res_jax, res_outside) - - def test_jax_grad(self): - def f_jax(x): - return 2. * jnp.sin(x) - - def f_outside(x): - return 2. * call_jax_other_device(jnp.sin, x, device=self.outside_device) - - res_jax = jax.grad(f_jax)(3.) - self.assertAllClose(res_jax, jax.grad(f_outside)(3.)) - - def test_jax_grad_pytree(self): - def f_jax(x): - # x : dict(a=..., b=...) and output is a float - return 3. * jnp.sin(x["a"]) + jnp.sin(x["b"]) - - def f_outside(x): - return call_jax_other_device(f_jax, x, device=self.outside_device) - - x = dict(a=3., b=4.) - res_jax = jax.grad(f_jax)(x) - self.assertAllClose(res_jax, jax.grad(f_outside)(x)) - - def test_jax_grad_of_grad(self): - def f_jax(x): - return 2. * x * x * x - - def f_outside(x): - return 2. * call_jax_other_device(lambda x: x * x * x, x, device=self.outside_device) - - res_jax = jax.grad(jax.grad(f_jax))(5.) - res_outside = jax.grad(jax.grad(f_outside))(5.) - self.assertAllClose(res_jax, res_outside) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py deleted file mode 100644 index 3a36ce1296a6..000000000000 --- a/tests/host_callback_to_tf_test.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""An example of using host_callback.call to invoke on the host functions -written in Tensorflow. The interesting aspect here is how we can differentiate -through the outside computation, using tf.GradientTape on the host. - -This is separate from host_callback_test because it needs a TF dependency. -""" -from collections.abc import Callable -import unittest - -from absl.testing import absltest -from absl.testing import parameterized - -import jax -from jax import numpy as jnp -from jax._src import config -from jax._src import test_util as jtu -from jax._src import xla_bridge -from jax.experimental import host_callback as hcb - -import numpy as np - -try: - import tensorflow as tf -except ImportError: - tf = None - -config.parse_flags_with_absl() - - -def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape): - """The simplest implementation of calling to TF, without AD support. - - We must use hcb.call because the TF invocation must happen outside the - JAX staged computation.""" - - def tf_to_numpy(t): - # Turn the Tensor to NumPy array without copying. - return np.asarray(memoryview(t)) if isinstance(t, tf.Tensor) else t - - return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy, - tf_fun(arg)), - arg, result_shape=result_shape, - callback_flavor=hcb.CallbackFlavor.DEBUG) - - -def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape): - """Calls a TensorFlow function with simple support for reverse AD. - - Works only for 1st order AD and only for arguments and results being a single - ndarray (no pytrees). Functions whose name starts with "tf_" are TensorFlow - functions and must be called outside the JAX computation. - """ - - @jax.custom_vjp - def make_call(arg): - """We wrap it all in `make_call` so that we can attach custom VJP.""" - return call_tf_no_ad(tf_fun, arg, result_shape=result_shape) - - # Define the fwd and bwd custom_vjp functions - def make_call_vjp_fwd(arg): - # Return the primal argument as the residual. Use `make_call` for the - # primal computation to enable higher-order AD. - return make_call(arg), arg - - def make_call_vjp_bwd(res, ct_res): - arg = res # residual is the primal argument - - def tf_vjp_fun(arg_and_ct_res): - """Invoke TF gradient; used with hcb.call.""" - arg, ct_res = arg_and_ct_res - arg_var = tf.Variable(arg) - with tf.GradientTape(persistent=True) as tape: - res = tf_fun(arg_var) - - dres_darg = tape.gradient(res, sources=arg_var, - output_gradients=ct_res, - unconnected_gradients=tf.UnconnectedGradients.ZERO) - return dres_darg - - return (call_tf_simple_ad(tf_vjp_fun, (arg, ct_res), - result_shape=arg),) - - make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) - return make_call(arg) - - -def call_tf_full_ad(tf_fun: Callable, arg, *, result_shape): - """Calls a TensorFlow function with support for reverse AD. - - Supports higher-order AD and pytree arguments. - """ - - @jax.custom_vjp - def make_call(arg): - """We wrap it all in `make_call` so that we can attach custom VJP.""" - return call_tf_no_ad(tf_fun, arg, result_shape=result_shape) - - # Define the fwd and bwd custom_vjp functions - def make_call_vjp_fwd(arg): - return make_call(arg), arg # Return the primal argument as the residual - - def make_call_vjp_bwd(res, ct_res): - arg = res # residual is the primal argument - - def tf_vjp_fun(arg_and_ct_res): - """Invoke TF gradient; used with hcb.call.""" - arg, ct_res = arg_and_ct_res - - def make_var(a): - return a if isinstance(a, tf.Variable) else tf.Variable(a) - - arg_var = tf.nest.map_structure(make_var, arg) - - with tf.GradientTape(persistent=True) as tape: - res = tf_fun(arg_var) - - tf.nest.assert_same_structure(res, ct_res) - accumulator = None # Accumulate argument cotangent. Same structure as "arg" - - def acc_ct(res_, ct_res_): - dres_darg = tape.gradient(res_, sources=arg_var, - unconnected_gradients=tf.UnconnectedGradients.ZERO) - tf.nest.assert_same_structure(dres_darg, arg) - scaled_dres_darg = tf.nest.map_structure(lambda d: d * ct_res_, dres_darg) - nonlocal accumulator - accumulator = (scaled_dres_darg if accumulator is None - else tf.nest.map_structure(lambda x, y: x + y, - accumulator, scaled_dres_darg)) - - tf.nest.map_structure(acc_ct, res, ct_res) - return accumulator - - return (call_tf_full_ad(tf_vjp_fun, (arg, ct_res), - result_shape=arg),) - - make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) - return make_call(arg) - - -CALL_TF_IMPLEMENTATIONS = { - "none": call_tf_no_ad, - "simple": call_tf_simple_ad, - "full": call_tf_full_ad, -} - - -class CallToTFTest(jtu.JaxTestCase): - - def setUp(self): - if tf is None: - raise unittest.SkipTest("Test requires tensorflow") - if xla_bridge.using_pjrt_c_api(): - raise unittest.SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - - @parameterized.named_parameters( - dict( - testcase_name=f"_{ad=}", - ad=ad) - for ad in CALL_TF_IMPLEMENTATIONS.keys()) - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_impl(self, ad="simple"): - self.supported_only_in_legacy_mode() - call_tf = CALL_TF_IMPLEMENTATIONS[ad] - - def f_jax(x): - return jnp.sin(x) - - def f_outside(x): - return call_tf(tf.math.sin, x, - result_shape=x) - - res = f_outside(3.) - self.assertAllClose(f_jax(3.), res) - self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.)) - - @parameterized.named_parameters( - dict( - testcase_name=f"_{ad=}", - ad=ad) - for ad in CALL_TF_IMPLEMENTATIONS.keys() - if ad != "none") - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_grad(self, ad="simple"): - self.supported_only_in_legacy_mode() - call_tf = CALL_TF_IMPLEMENTATIONS[ad] - - def f_jax(x): - return 3. * jnp.sin(2. * x) - - def f_outside(x): - return 3. * call_tf( - lambda x: tf.cast(tf.math.sin(x), tf.float32), 2. * x, - result_shape=jax.ShapeDtypeStruct((), np.float32)) - - x = np.float32(4.) - self.assertAllClose(f_jax(x), f_outside(x), - check_dtypes=False) - - grad_f = jax.grad(f_outside)(x) - self.assertAllClose(jax.grad(f_jax)(x), grad_f, - check_dtypes=False) - - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_grad_pytree(self): - self.supported_only_in_legacy_mode() - call_tf = call_tf_full_ad - - def f_jax(xy): - dict_ab = dict(a=2. * xy[0], b=xy[0] * xy[1]) - return 3. * dict_ab["a"] + 4. * dict_ab["b"] - - def f_outside(xy): - dict_ab = call_tf( - lambda xy: dict(a=tf.cast(2. * xy[0], np.float32), - b=tf.cast(xy[0] * xy[1], np.float32)), - xy, - result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32), - b=jax.ShapeDtypeStruct((), np.float32))) - return 3. * dict_ab["a"] + 4. * dict_ab["b"] - - xy = (5., 6.) - self.assertAllClose(f_jax(xy), f_outside(xy), - check_dtypes=False) - res_jax = jax.grad(f_jax)(xy) - self.assertAllClose(res_jax, jax.grad(f_outside)(xy), - check_dtypes=False) - - @parameterized.named_parameters( - dict( - testcase_name=f"_degree=_{degree}", - degree=degree) - for degree in [1, 2, 3, 4]) - @jtu.ignore_warning(message="The host_callback APIs are deprecated", - category=DeprecationWarning) - def test_higher_order_grad(self, degree=4): - self.supported_only_in_legacy_mode() - call_tf = call_tf_full_ad - - def f_jax(x): - return 2. * x * x * x - - def f_outside(x): - return 2. * call_tf(lambda y: y * y * y, x, - result_shape=x) - - grad_jax = f_jax - grad_outside = f_outside - for i in range(degree): - grad_jax = jax.grad(grad_jax) - grad_outside = jax.grad(grad_outside) - - res_jax = grad_jax(5.) - self.assertAllClose(res_jax, grad_outside(5.)) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/infeed_test.py b/tests/infeed_test.py index ba47d2417f94..e378fe37a2f5 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -19,7 +19,6 @@ from absl.testing import absltest import jax from jax import lax, numpy as jnp -from jax.experimental import host_callback as hcb from jax._src import core from jax._src import xla_bridge from jax._src.lib import xla_client @@ -77,7 +76,6 @@ def f(x): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. def testInfeedThenOutfeed(self): - hcb._deprecated_stop_outfeed_receiver() @jax.jit def f(x): @@ -99,7 +97,6 @@ def f(x): self.assertAllClose(out, y + np.float32(1)) def testInfeedThenOutfeedInALoop(self): - hcb._deprecated_stop_outfeed_receiver() def doubler(_, token): y, token = lax.infeed(