From f16956ee92ca0b038fd372b5c3fe212847e74492 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 29 Mar 2024 13:36:20 +0200 Subject: [PATCH] [callback] Add a flag to implement host_callback in terms of io_callback. The host_callbacks APIs are deprecated and will be removed. In order to help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`) that when set to `False` will use `io_callback` (and `pure_callback` and `jax.debug.callback`) to implement the host_callback APIs. See issue #20385 for more details. We change the tests to accommodate slightly different results when using the new callbacks. The tests that use `tap_with_device` and `call_with_device` are disabled when using the new callbacks. --- CHANGELOG.md | 4 +- jax/BUILD | 6 +- jax/experimental/host_callback.py | 116 ++++++- tests/host_callback_test.py | 490 +++++++++++++++++++++--------- tests/host_callback_to_tf_test.py | 16 +- tests/python_callback_test.py | 2 +- 6 files changed, 463 insertions(+), 171 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ab4ea6fc3e20..24860825a39e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,8 @@ Remember to align the itemized text with the first line of an item within a list `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the + new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception. * The deprecated flag `jax_parallel_functions_output_gda` has been removed. @@ -1451,7 +1453,7 @@ Changes: special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. 
The old behavior can be obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS`` - environment variable, or the ```--flax_host_callback_ad_transforms``` flag. + environment variable, or the ```--jax_host_callback_ad_transforms``` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs ({jax-issue}`#8678`). * Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the diff --git a/jax/BUILD b/jax/BUILD index 7998d028f206..d2bb6d31d1dd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -997,7 +997,11 @@ pytype_library( pytype_library( name = "experimental_host_callback", - srcs = ["experimental/host_callback.py"], + srcs = [ + "experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False + "experimental/host_callback.py", + "experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False + ], visibility = ["//visibility:public"], deps = [ ":jax", diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index da2dbc79d964..d50a8f7fe8a2 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -17,6 +17,7 @@ The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_ + See https://github.com/google/jax/issues/20385. 
This module introduces the host callback functions :func:`call`, :func:`id_tap`, and :func:`id_print`, that send their arguments from the device @@ -501,6 +502,7 @@ def power3_with_cotangents(x): from __future__ import annotations import atexit +import enum from collections.abc import Sequence import functools import itertools @@ -510,6 +512,7 @@ def power3_with_cotangents(x): import traceback from typing import Any, Callable, cast +import jax from jax._src import api from jax._src import core from jax._src import config @@ -517,6 +520,7 @@ def power3_with_cotangents(x): from jax._src import dtypes from jax import lax from jax.experimental import pjit +from jax.experimental import io_callback from jax._src.interpreters import ad, batching, pxla from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -560,6 +564,15 @@ def power3_with_cotangents(x): 'Has no effect on TPU, since only the outfeed mechanism is implemented.' ) ) +_HOST_CALLBACK_LEGACY = config.DEFINE_bool( + 'jax_host_callback_legacy', + config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), + help=( + 'Use old implementation of host_callback, documented in the module docstring.' + 'If False, use the jax.experimental.io_callback implementation. ' + 'See https://github.com/google/jax/issues/20385.' + ) +) logger = logging.getLogger(__name__) @@ -591,6 +604,15 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): XlaLocalClient = xla_client.Client DType = Any +class CallbackFlavor(enum.Enum): + """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. + + See https://github.com/google/jax/issues/20385. 
+ """ + IO_CALLBACK = 1 # uses jax.experimental.io_callback + PURE = 2 # uses jax.pure_callback + DEBUG = 3 # uses jax.debug.callback, valid only when there are no results + def _deprecated_id_tap(tap_func, arg, @@ -598,6 +620,7 @@ def _deprecated_id_tap(tap_func, result=None, tap_with_device=False, device_index=0, + callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs): """Host-callback tap primitive, like identity function with a call to ``tap_func``. @@ -605,6 +628,7 @@ def _deprecated_id_tap(tap_func, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ + See https://github.com/google/jax/issues/20385. ``id_tap`` behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime @@ -628,6 +652,9 @@ def _deprecated_id_tap(tap_func, device_index: specifies from which device the tap function is invoked in a SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. + callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies + the flavor of callback to use. + See https://github.com/google/jax/issues/20385. Returns: ``arg``, or ``result`` if given. @@ -660,7 +687,8 @@ def _deprecated_id_tap(tap_func, call_with_device=tap_with_device, result_shape=None, identity=True, - device_index=device_index) + device_index=device_index, + callback_flavor=callback_flavor) if result is not None: return result @@ -675,6 +703,7 @@ def _deprecated_id_print(arg, device_index=0, output_stream=None, threshold=None, + callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs): """Like :func:`id_tap` with a printing tap function. @@ -682,6 +711,7 @@ def _deprecated_id_print(arg, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ + See https://github.com/google/jax/issues/20385. 
On each invocation of the printing tap, the ``kwargs`` if present will be printed first (sorted by keys). Then arg will be printed, @@ -697,6 +727,9 @@ def _deprecated_id_print(arg, built-in ``print``. The string will be passed as ``output_stream.write(s)``. * ``threshold`` is passed to ``numpy.array2string``. + * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies + the flavor of callback to use. + See https://github.com/google/jax/issues/20385. For more details see the :mod:`jax.experimental.host_callback` module documentation. """ @@ -708,19 +741,22 @@ def _deprecated_id_print(arg, arg, result=result, tap_with_device=tap_with_device, - device_index=device_index) + device_index=device_index, + callback_flavor=callback_flavor) def _deprecated_call(callback_func: Callable, arg, *, result_shape=None, call_with_device=False, - device_index=0): + device_index=0, + callback_flavor=CallbackFlavor.IO_CALLBACK): """Make a call to the host, and expect a result. .. warning:: The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_ + See https://github.com/google/jax/issues/20385. Args: callback_func: The Python function to invoke on the host as @@ -748,14 +784,26 @@ def _deprecated_call(callback_func: Callable, arg, *, device_index: specifies from which device the tap function is invoked in a SPMD program. Works only when using the outfeed implementation mechanism, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. + callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies + the flavor of callback to use. + See https://github.com/google/jax/issues/20385. + Returns: the result of the ``callback_func`` invocation. For more details see the :mod:`jax.experimental.host_callback` module documentation. 
""" + if (not _HOST_CALLBACK_LEGACY.value and + callback_flavor == CallbackFlavor.DEBUG and + result_shape is not None): + raise NotImplementedError( + "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " + "flavor of callback only when the `result_shape` is None. " + "See https://github.com/google/jax/issues/20385." + ) return _call(callback_func, arg, result_shape=result_shape, call_with_device=call_with_device, identity=False, - device_index=device_index) + device_index=device_index, callback_flavor=callback_flavor) # We need the wrapper function to have hash and equality defined since it is @@ -766,6 +814,11 @@ def __init__(self, callback_func, identity, call_with_device): self.callback_func = callback_func self.identity = identity self.call_with_device = call_with_device + if not _HOST_CALLBACK_LEGACY.value and call_with_device: + raise NotImplementedError( + "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" + " do not support `tap_with_device` and `call_with_device`. 
" + "See https://github.com/google/jax/issues/20385.") def __hash__(self): return hash((self.callback_func, self.identity, self.call_with_device)) @@ -775,7 +828,16 @@ def __eq__(self, other): self.identity == other.identity and self.call_with_device == other.call_with_device) - def __call__(self, arg, device, transforms): + def __call__(self, *args, **kwargs): + if _HOST_CALLBACK_LEGACY.value: + return self._call_legacy(*args, **kwargs) + else: + if self.identity: + # For id_tap, we pass empty transforms, for backwards compatibility + return self.callback_func(args[0], ()) + return self.callback_func(*args, **kwargs) + + def _call_legacy(self, arg, device, transforms): if self.identity: # For id_tap, we pass the transforms, for backwards compatibility if self.call_with_device: @@ -797,14 +859,16 @@ def _call(callback_func: Callable, result_shape=None, call_with_device=False, device_index=0, - identity=False): - # Lazy initialization - _initialize_outfeed_receiver( - max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) + identity=False, + callback_flavor=CallbackFlavor.IO_CALLBACK): + if _HOST_CALLBACK_LEGACY.value: + # Lazy initialization + _initialize_outfeed_receiver( + max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) api.check_callable(callback_func) flat_args, arg_treedef = tree_util.tree_flatten(arg) - for arg in flat_args: - dispatch.check_arg(arg) + for arg_ in flat_args: + dispatch.check_arg(arg_) # See definition of outside_call_p for what parameters it takes params: dict[str, Any] = {} # TODO: wrap function @@ -829,8 +893,27 @@ def _call(callback_func: Callable, params["result_treedef"] = result_treedef params["flat_results_aval"] = tuple(flat_results_aval) - flat_results = outside_call_p.bind(*flat_args, **params) - return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) + + if _HOST_CALLBACK_LEGACY.value: + flat_results = outside_call_p.bind(*flat_args, **params) + 
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results) + else: + callback_device = jax.local_devices()[device_index] + sharding = jax.sharding.SingleDeviceSharding(callback_device) + callback_func = _CallbackWrapper(callback_func, identity, + call_with_device) + if callback_flavor == CallbackFlavor.DEBUG: + assert identity + jax.debug.callback(callback_func, arg) + return arg + elif callback_flavor == CallbackFlavor.PURE: + call_res = jax.pure_callback(callback_func, result_shape, arg, + sharding=sharding) + else: + call_res = io_callback(callback_func, result_shape, arg, + sharding=sharding, + ordered=True) + return call_res if not identity else arg # We need the lock for when we use the CustomCall implementation of callbacks. @@ -855,7 +938,6 @@ def _print_tap_func( threshold: the value of numpy.array2string threshold parameter. **kwargs: all other keyword args are printed before printing `arg`. """ - def emit_str(s: str): if output_stream is not None: output_stream.write(s + "\n") @@ -1844,6 +1926,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None): For more details see the :mod:`jax.experimental.host_callback` module documentation. """ + if not _HOST_CALLBACK_LEGACY.value: + jax.effects_barrier() + return + logging_name = logging_name or "" logger.debug("barrier_wait[%s]: start", logging_name) @@ -1907,7 +1993,7 @@ def _deprecated_stop_outfeed_receiver(): _deprecation_msg = ( "The host_callback APIs are deprecated as of March 20, 2024. The functionality " "is subsumed by the new JAX external callbacks. 
" - "See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.") + "See https://github.com/google/jax/issues/20385.") _deprecations = { # Added March 20, 2024 diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index f12218449331..9c5ab78cbd9e 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -91,8 +91,10 @@ def reset(self): def fun1(a): """Function used for several `id_tap` tests.""" - y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream) - y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y) + y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) + y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y, + callback_flavor=hcb.CallbackFlavor.DEBUG) return y ** 2 # Some computation to make the gradient interesting @@ -253,6 +255,10 @@ def tearDown(self) -> None: hcb.barrier_wait("HostCallbackTapTest.tearDown") super().tearDown() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + def test_tap_eval(self): self.assertAllClose((5. * 2.) ** 2, fun1(5.)) hcb.barrier_wait() @@ -320,6 +326,7 @@ def func2(x): testing_stream.output) def test_tap_with_device(self): + self.supported_only_in_legacy_mode() def func2(x): x1 = hcb.id_print((x * 2., x * 3.), result=x * 4., output_stream=testing_stream, @@ -335,6 +342,7 @@ def func2(x): def test_tap_eval_exception(self): if not hcb._HOST_CALLBACK_OUTFEED.value: raise SkipTest("TODO: implement error handling for customcall") + # Simulate a tap error def tap_err(*args, **kwargs): raise ValueError("Some user message") @@ -345,19 +353,30 @@ def func(x): x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 - with self.assertRaisesRegex( - hcb.CallbackException, - re.compile("There were exceptions during callback processing. 
Last one was:.*" - "ValueError: Some user message", re.DOTALL)): + if hcb._HOST_CALLBACK_LEGACY.value: + ctx = self.assertRaisesRegex( + hcb.CallbackException, + re.compile("There were exceptions during callback processing. Last one was:.*" + "ValueError: Some user message", re.DOTALL)) + else: + ctx = self.assertRaisesRegex(Exception, "Some user message") + + with ctx: func(0) hcb.barrier_wait() - # We should have received everything before the error - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + # We should have received everything before the error + assertMultiLineStrippedEqual(self, """ + what: x1 + 1 + what: x3 + 3""", testing_stream.output) + else: + # We should have received everything before the error + assertMultiLineStrippedEqual(self, """ + what: x1 + 1""", testing_stream.output) def test_tap_empty(self): """Tap empty arrays.""" @@ -488,6 +507,7 @@ def func_nested(x): def test_tap_jit_devices(self): """Running on multiple devices.""" + self.supported_only_in_legacy_mode() logging.info("%s: has devices %s", self._testMethodName, local_devices()) def func(x, device_id): @@ -830,19 +850,24 @@ def func(x): x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 - res = jax.jit(func)(0) # No error yet - with self.assertRaises(hcb.CallbackException): - hcb.barrier_wait() - - # Even though the receiver thread raised, the main thread should still - # return 3. - self.assertEqual(3, res) - # We should have received all others - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + res = jax.jit(func)(0) # No error yet + with self.assertRaises(hcb.CallbackException): + hcb.barrier_wait() + + # Even though the receiver thread raised, the main thread should still + # return 3. 
+ self.assertEqual(3, res) + # We should have received all others + assertMultiLineStrippedEqual(self, """ + what: x1 + 1 + what: x3 + 3""", testing_stream.output) + else: + with self.assertRaisesRegex(Exception, "NotImplementedError"): + res = jax.jit(func)(0) + hcb.barrier_wait() def test_tap_while(self): """Executing while, even without JIT uses compiled code""" @@ -878,7 +903,8 @@ def test_tap_grad_primal_unused(self): # The output of id_print is not needed for backwards pass def func(x): return 2. * hcb.id_print(x * 3., what="x * 3", - output_stream=testing_stream) + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) grad_func = jax.grad(func) arg = jnp.float32(5.) @@ -886,21 +912,22 @@ def func(x): # making the Jaxpr does not print anything hcb.barrier_wait() - treedef = jax.tree.structure(arg) - assertMultiLineStrippedEqual( - self, f""" - {{ lambda ; a:f32[]. let - b:f32[] = mul a 3.00 - c:f32[] = outside_call[ - arg_treedef={treedef} - callback=... - device_index=0 - identity=True - ] b - _:f32[] = mul 2.00 c - d:f32[] = mul 2.00 1.00 - e:f32[] = mul d 3.00 - in (e,) }}""", jaxpr) + if hcb._HOST_CALLBACK_LEGACY.value: + treedef = jax.tree.structure(arg) + assertMultiLineStrippedEqual( + self, f""" + {{ lambda ; a:f32[]. let + b:f32[] = mul a 3.00 + c:f32[] = outside_call[ + arg_treedef={treedef} + callback=... 
+ device_index=0 + identity=True + ] b + _:f32[] = mul 2.00 c + d:f32[] = mul 2.00 1.00 + e:f32[] = mul d 3.00 + in (e,) }}""", jaxpr) assertMultiLineStrippedEqual(self, "", testing_stream.output) testing_stream.reset() @@ -914,9 +941,11 @@ def func(x): def test_tap_grad_simple(self): def func(x): - y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) + y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * hcb.id_print(y * 3., what="y * 3", - output_stream=testing_stream) + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) grad_func = jax.grad(func) @@ -931,7 +960,8 @@ def func(x): def test_tap_grad_grad(self): def func(x): - y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) + y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * (y * 3.) grad_func = jax.grad(jax.grad(func)) @@ -952,7 +982,8 @@ def test_tap_grad_pytree(self): def func(x): x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair", result=(x * 4., x * 5.), - output_stream=testing_stream) + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x4 + 2. * x5 x = jnp.float32(5.) 
@@ -967,15 +998,18 @@ def func(x): def test_tap_jvp_float0(self): def f(x, yint): - x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint)) + x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint), + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * yint res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0))) self.assertAllClose((6., 0.6), res) def test_tap_grad_float0(self): + def func(x, yint): - x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream) + x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x * yint.astype(x.dtype) grad_func = jax.grad(func) @@ -993,7 +1027,8 @@ def test_tap_grad_float0_result(self): x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) def f_jax(x): - x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important + x = hcb.id_print(x, result=x, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important return (3. 
* x[0], x[1]) def f_jax_vjp(x): @@ -1015,7 +1050,8 @@ def test_tap_higher_order_grad_float0_result(self): x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) def f_jax(x): - x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important + x = hcb.id_print(x, result=x, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) # result= is important return (jnp.sin(x[0]), x[1]) def wrap_vjp(f, args, res_f_of_args): @@ -1059,32 +1095,52 @@ def test_tap_vmap(self): vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) vmap_fun1(vargs) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 - [ 8.00 10.00] - transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 - [24.00 30.00]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (0,)})] what: a * 2 + [ 8.00 10.00] + transforms: [('batch', {'batch_dims': (0,)})] what: y * 3 + [24.00 30.00]""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + what: a * 2 + 8.00 + what: a * 2 + 10.00 + what: y * 3 + 24.00 + what: y * 3 + 30.00 + """, testing_stream.output) def test_tap_vmap_not_batched(self): x = 3. 
def func(y): # x is not mapped, y is mapped - _, y = hcb.id_print((x, y), output_stream=testing_stream) + _, y = hcb.id_print((x, y), output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return x + y vmap_func = jax.vmap(func) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) _ = vmap_func(vargs) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (None, 0)})] - ( 3.00 [4.00 5.00] )""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (None, 0)})] + ( 3.00 [4.00 5.00] )""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + ( 3.00 4.00 ) + ( 3.00 5.00 ) + """, testing_stream.output) def test_tap_vmap_vmap(self): # A 2D tensor with x[i, j] = i + j using 2 vmap def sum(x, y): - return hcb.id_print(x + y, output_stream=testing_stream) + return hcb.id_print(x + y, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) def sum_rows(xv, y): return jax.vmap(sum, in_axes=(0, None))(xv, y) @@ -1097,22 +1153,44 @@ def sum_all(xv, yv): # assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv))) _ = sum_all(xv, yv) hcb.barrier_wait() - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] - [[0 1 2 3 4] - [1 2 3 4 5] - [2 3 4 5 6]]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})] + [[0 1 2 3 4] + [1 2 3 4 5] + [2 3 4 5 6]]""", testing_stream.output) + else: + assertMultiLineStrippedEqual(self, """ + 0 + 1 + 2 + 1 + 2 + 3 + 2 + 3 + 4 + 3 + 4 + 5 + 4 + 5 + 6 + """, testing_stream.output) def test_tap_vmap_while(self): """Vmap of while.""" def func(x): # like max(x, 2) - x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream) + x1 = 
hcb.id_print(x, where="before:x", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) x2 = lax.while_loop( lambda x: x < 2, lambda x: hcb.id_print( - x + 1, where="body:x+1", output_stream=testing_stream), x1) - res = hcb.id_print(x2, where="after:x", output_stream=testing_stream) + x + 1, where="body:x+1", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG), x1) + res = hcb.id_print(x2, where="after:x", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return res inputs = np.arange(5, dtype=np.int32) @@ -1121,72 +1199,93 @@ def func(x): jax.jit(jax.vmap(func))(inputs), check_dtypes=False) hcb.barrier_wait() - assertMultiLineStrippedEqual( - self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: before:x - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: after:x - [2 2 2 3 4]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual( + self, """ + transforms: [('batch', {'batch_dims': (0,)})] where: before:x + [0 1 2 3 4] + transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 + [1 2 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1 + [2 3 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: after:x + [2 2 2 3 4]""", testing_stream.output) + else: + pass # order of vmaps is not guaranteed def test_tap_vmap_while_tap_cond(self): """Vmap of while, with a tap in the conditional.""" def func(x): # like max(x, 2) - x1 = hcb.id_print(x, where="1", output_stream=testing_stream) + x1 = hcb.id_print(x, where="1", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c", - output_stream=testing_stream), + output_stream=testing_stream, + 
callback_flavor=hcb.CallbackFlavor.DEBUG), lambda x: hcb.id_print(x + 1, where="w_b", - output_stream=testing_stream), + output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG), x1) - res = hcb.id_print(x2, where="3", output_stream=testing_stream) + res = hcb.id_print(x2, where="3", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return res inputs = np.arange(5, dtype=np.int32) res = jax.jit(jax.vmap(func))(inputs) hcb.barrier_wait() self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) - assertMultiLineStrippedEqual(self, """ - transforms: [('batch', {'batch_dims': (0,)})] where: 1 - [0 1 2 3 4] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True True False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [1 2 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [ True False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: w_b - [2 3 3 4 5] - transforms: [('batch', {'batch_dims': (0,)})] where: w_c - [False False False False False] - transforms: [('batch', {'batch_dims': (0,)})] where: 3 - [2 2 2 3 4]""", testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + assertMultiLineStrippedEqual(self, """ + transforms: [('batch', {'batch_dims': (0,)})] where: 1 + [0 1 2 3 4] + transforms: [('batch', {'batch_dims': (0,)})] where: w_c + [ True True False False False] + transforms: [('batch', {'batch_dims': (0,)})] where: w_b + [1 2 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: w_c + [ True False False False False] + transforms: [('batch', {'batch_dims': (0,)})] where: w_b + [2 3 3 4 5] + transforms: [('batch', {'batch_dims': (0,)})] where: w_c + [False False False False False] + transforms: [('batch', {'batch_dims': (0,)})] where: 3 + [2 2 2 3 4]""", testing_stream.output) + else: + pass # order of vmap is not guaranteed def test_tap_transforms_doc(self): # Examples from the documentation def power3(x): y = x 
* x # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return y * x print(f"impl = {power3(3.)}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1197,32 +1296,41 @@ def print_tangents(arg): @print_tangents.defjvp def print_tangents_jvp(primals, tangents): arg_dot, = tangents - hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream) + hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return primals, tangents def power3_with_tangents(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) print_tangents((x, y)) return y * x print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. ) - what: tangents - ( 0.1 0.6 )""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. 
) + what: tangents + ( 0.1 0.6 )""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() print(f"grad = {jax.grad(power3)(3.)}") hcb.barrier_wait() # Only the primals by default - expected = """ - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1236,7 +1344,8 @@ def print_cotangents_fwd(arg): return print_cotangents(arg), None # f_bwd: (residual, CT b) -> [CT a] def print_cotangents_bwd(residual, ct_b): - hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) + hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return ct_b, print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) @@ -1244,18 +1353,26 @@ def print_cotangents_bwd(residual, ct_b): def power3_with_cotangents(x): y = x * x # Print both 'x' and 'x^2'. Must pack as a tuple. - hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream) + hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) # Must use the output of print_cotangents (x1, y1) = print_cotangents((x, y)) return y1 * x1 print(f"grad = {jax.grad(power3_with_cotangents)(3.)}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. ) - what: cotangents - ( 9. 3. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. ) + what: cotangents + ( 9. 3. 
)""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 ) + what: cotangents + ( 9.0 3.0 )""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() @@ -1263,43 +1380,82 @@ def power3_with_cotangents(x): print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}") hcb.barrier_wait() - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] )""" + else: + expected = """ + what: x,x^2 + ( 2.0 4.0 ) + what: x,x^2 + ( 3.0 9.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}") hcb.barrier_wait() - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] )""" + else: + expected = """ + what: x,x^2 + ( 2.0 4.0 ) + what: x,x^2 + ( 3.0 9.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}") hcb.barrier_wait() - expected = """ - transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 - ( [2. 3.] [4. 9.] ) - transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents - ( [4. 9.] [2. 3.] )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 + ( [2. 3.] [4. 9.] ) + transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents + ( [4. 9.] [2. 3.] 
)""" + else: + expected = """ + what: x,x^2 + ( 2.0 4.0 ) + what: x,x^2 + ( 3.0 9.0 ) + what: cotangents + ( 4.0 2.0 ) + what: cotangents + ( 9.0 3.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}") hcb.barrier_wait() - expected = """ - what: x,x^2 - ( 3. 9. ) - what: x,x^2 - ( 27. 729. ) - what: x,x^2 - ( 3. 9. )""" + if hcb._HOST_CALLBACK_LEGACY.value: + expected = """ + what: x,x^2 + ( 3. 9. ) + what: x,x^2 + ( 27. 729. ) + what: x,x^2 + ( 3. 9. )""" + else: + expected = """ + what: x,x^2 + ( 3.0 9.0 ) + what: x,x^2 + ( 27.0 729.0 ) + what: x,x^2 + ( 3.0 9.0 ) + """ self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() def test_tap_pmap(self): + self.supported_only_in_legacy_mode() if len(local_devices()) < 2: raise SkipTest("test requires at least 2 devices") @@ -1326,6 +1482,7 @@ def power3(x): ( 4 16 )""") def test_tap_pmap_vmap(self): + self.supported_only_in_legacy_mode() # A matrix M[ij] = i * 10 + j nr_devices = len(local_devices()) shape = (nr_devices, 3) @@ -1353,6 +1510,7 @@ def fun1(x, do_print=False): # x: i32 def test_tap_pmap_pmap_vmap(self): # A matrix M[ijk] = i * 100 + j * 10 + k + self.supported_only_in_legacy_mode() nr_devices = len(local_devices()) if nr_devices % 2 != 0: raise SkipTest("test works only on even number of devices") @@ -1386,6 +1544,7 @@ def fun1(x, do_print=False): # x: f32 def test_tap_pmap_pmap_extra(self): """pmap of a pmap surrounded by extra code.""" # A matrix M[ij] = i * 10 + j + self.supported_only_in_legacy_mode() nr_devices = len(local_devices()) if nr_devices != 2: raise SkipTest("test works only on 2 devices") @@ -1419,6 +1578,7 @@ def fun(xv, do_print=False): [[203.00 205.00 207.00]]""") def test_tap_jvp_pmap_vmap(self): + self.supported_only_in_legacy_mode() # A matrix M[ijk] = i * 100 + j * 10 * k nr_devices = len(local_devices()) 
shape = (nr_devices, 2, 3) @@ -1445,6 +1605,7 @@ def fun(xv, do_print=False): [220.00 222.00 224.00]]""") def test_tap_vmap_pmap(self): + self.supported_only_in_legacy_mode() # A matrix M[ijk] = i * 100 + j * 10 * k nr_devices = len(local_devices()) shape = (2, nr_devices, 3) @@ -1472,6 +1633,7 @@ def fun(xv, do_print=False): @ignore_jit_of_pmap_warning() def test_tap_jit_pmap_extra(self): """jit of a pmap surrounded by extra code.""" + self.supported_only_in_legacy_mode() # A matrix M[ij] = i * 10 + j nr_devices = len(local_devices()) assert nr_devices in (1, 2) @@ -1540,6 +1702,7 @@ def fun2(cond, xv, do_print=False): @jtu.sample_product(device_index=[0, 1]) def test_tap_pjit(self, device_index=0): + self.supported_only_in_legacy_mode() if (device_index != 0 and not hcb._HOST_CALLBACK_OUTFEED.value and jtu.test_device_matches(["cpu"])): @@ -1589,7 +1752,7 @@ def fun1(x): def test_tap_scan_custom_jvp(self): """custom JVP, inside scan. This exercises the custom_jvp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_jvp def f(x): return x * hcb.id_print(x, output_stream=testing_stream, what="x") @@ -1633,7 +1796,7 @@ def g(x): def test_tap_scan_custom_vjp(self): """custom VJP, inside scan. 
This exercises the custom_vjp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_vjp def f(x): return x * hcb.id_print(x, output_stream=testing_stream, what="x") @@ -1773,7 +1936,7 @@ def test_tap_odeint(self): from jax.experimental.ode import odeint def f(x, t, k): - x = hcb.id_print(x) + x = hcb.id_print(x, callback_flavor=hcb.CallbackFlavor.DEBUG) return -k * x def loss(k=1.0): @@ -1785,7 +1948,8 @@ def loss(k=1.0): def test_tap_remat_0(self): def f(i, k): - x = hcb.id_print(k + i, output_stream=testing_stream) + x = hcb.id_print(k + i, output_stream=testing_stream, + callback_flavor=hcb.CallbackFlavor.DEBUG) return k * x def loss(k): @@ -1804,6 +1968,7 @@ def loss(k): use_remat=["old", "new", "none"], ) def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): + self.supported_only_in_legacy_mode() if use_remat == "old": raise SkipTest() def f(x): @@ -1880,6 +2045,10 @@ def tearDown(self) -> None: hcb.barrier_wait("HostCallbackCallTest.tearDown") super().tearDown() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + def call_log_testing_stream(self, func, arg, *, result_shape, name=""): """Call `func` and log inputs and outputs to the testing stream""" @@ -1916,6 +2085,7 @@ def fun(x): with jtu.count_primitive_compiles() as count: for _ in range(3): self.assertAllClose(2 * arg, fun(arg)) + r = jax.make_jaxpr(fun)(arg) self.assertEqual(count[0], 1) @jtu.sample_product( @@ -2124,6 +2294,7 @@ def fun2(m): helper_print_optimized_hlo(fun2, m) def test_call_with_device(self): + self.supported_only_in_legacy_mode() def callback_func(x, device=None): testing_stream.write(f"device: {device}\n Called with {x}") return x @@ -2139,6 +2310,7 @@ def func(x): Called with 3.00""") def test_call_pmap(self): + self.supported_only_in_legacy_mode() # Works for 1 or 2 devices def callback_func(x, device=None): 
testing_stream.write(f"device: {device}\n Called with {x}") @@ -2163,10 +2335,14 @@ def test_call_vmap(self): def f_outside(x): return x def fun(x): - return hcb.call(f_outside, x, result_shape=x) + return hcb.call(f_outside, x, result_shape=x, + callback_flavor=hcb.CallbackFlavor.PURE) - with self.assertRaisesRegex(NotImplementedError, - "batching rules are implemented only for id_tap, not for call"): + if hcb._HOST_CALLBACK_LEGACY.value: + with self.assertRaisesRegex(NotImplementedError, + "batching rules are implemented only for id_tap, not for call"): + jax.vmap(fun)(np.ones((2, 3))) + else: jax.vmap(fun)(np.ones((2, 3))) @jtu.sample_product(device_index=[0, 1]) @@ -2256,6 +2432,7 @@ def helper_check_callback_errors(self, thunk: Callable, hcb.barrier_wait("Waiting for error") def test_call_error_callback_throws_exception(self): + self.supported_only_in_legacy_mode() def f_outside(x): raise ValueError("user exception") def fun(x): @@ -2265,6 +2442,7 @@ def fun(x): "ValueError: user exception") def test_call_error_callback_returns_unexpected_shape(self): + self.supported_only_in_legacy_mode() def fun(x): return hcb.call(lambda x: (x, x), x, result_shape=x) @@ -2272,6 +2450,7 @@ def fun(x): "Callback func .* should have returned a result with pytree") def test_call_error_then_compute(self): + self.supported_only_in_legacy_mode() # Continue computation on device after error def f_outside(x): raise ValueError("user exception") @@ -2283,7 +2462,9 @@ def fun(x): "ValueError: user exception") -def call_jax_other_device(jax_outside_fun, arg, *, device): +def call_jax_other_device( + jax_outside_fun, arg, *, device, + callback_flavor: hcb.CallbackFlavor = hcb.CallbackFlavor.IO_CALLBACK): """Calls a JAX function on a specific device with simple support for reverse AD. 
Functions whose name starts with "jax_outside" are called on another device, @@ -2296,7 +2477,8 @@ def run_jax_outside_fun(arg): @jax.custom_vjp def make_call(arg): return hcb.call(run_jax_outside_fun, arg, - result_shape=jax.eval_shape(jax_outside_fun, arg)) + result_shape=jax.eval_shape(jax_outside_fun, arg), + callback_flavor=callback_flavor) # Define the fwd and bwd custom_vjp functions def make_call_vjp_fwd(arg): @@ -2323,6 +2505,8 @@ class CallJaxTest(jtu.JaxTestCase): """Tests using `call_jax_other_device`.""" def setUp(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") if xla_bridge.using_pjrt_c_api(): @@ -2337,6 +2521,7 @@ def setUp(self): self.outside_device = jax.devices("cpu")[1] super().setUp() + def test_jax_impl(self): def f_jax(x): return jnp.sin(x) @@ -2404,6 +2589,10 @@ def setUp(self): raise SkipTest("host_callback not implemented in PJRT C API") super().setUp() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + def assertRewrite(self, expected: str, func: Callable, args: Sequence, has_input_token=True, has_output_token=True): """Check that the rewrite of func(*args) matches expected.""" @@ -2624,7 +2813,7 @@ def func(x): def test_scan_custom_jvp(self): """custom JVP, inside scan. This exercises the custom_jvp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_jvp def f(x): return x * hcb.id_print(x) @@ -2706,7 +2895,7 @@ def g(x): def test_scan_custom_vjp(self): """custom VJP, inside scan. 
This exercises the custom_vjp_call_jaxpr primitives.""" - + self.supported_only_in_legacy_mode() @jax.custom_vjp def f(x): return x * hcb.id_print(x) @@ -2849,6 +3038,7 @@ def step(acc, step_nr): in (c, d, e) }""", tap_scalar, [np.int32(3)]) def test_pmap(self): + self.supported_only_in_legacy_mode() def f(xv): jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)), axis_name="i")(xv) diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py index c8858e14084c..5702ae1e78ad 100644 --- a/tests/host_callback_to_tf_test.py +++ b/tests/host_callback_to_tf_test.py @@ -53,7 +53,8 @@ def tf_to_numpy(t): return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy, tf_fun(arg)), - arg, result_shape=result_shape) + arg, result_shape=result_shape, + callback_flavor=hcb.CallbackFlavor.DEBUG) def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape): @@ -166,12 +167,17 @@ def setUp(self): raise unittest.SkipTest("host_callback not implemented in PJRT C API") super().setUp() + def supported_only_in_legacy_mode(self): + if not hcb._HOST_CALLBACK_LEGACY.value: + self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") + @parameterized.named_parameters( dict( testcase_name=f"_{ad=}", ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys()) def test_impl(self, ad="simple"): + self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] def f_jax(x): @@ -192,13 +198,14 @@ def f_outside(x): for ad in CALL_TF_IMPLEMENTATIONS.keys() if ad != "none") def test_grad(self, ad="simple"): + self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] def f_jax(x): return 3. * jnp.sin(2. * x) def f_outside(x): - return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x) + return 3. * call_tf(tf.math.sin, 2. * x, result_shape=np.asarray(x)) x = 4. 
self.assertAllClose(f_jax(x), f_outside(x)) @@ -207,6 +214,7 @@ def f_outside(x): self.assertAllClose(jax.grad(f_jax)(x), grad_f) def test_grad_pytree(self): + self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad def f_jax(xy): @@ -217,7 +225,8 @@ def f_outside(xy): dict_ab = call_tf( lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]), xy, - result_shape=dict(a=xy[0], b=xy[1])) + result_shape=dict(a=jax.ShapeDtypeStruct((), np.float32), + b=jax.ShapeDtypeStruct((), np.float32))) return 3. * dict_ab["a"] + 4. * dict_ab["b"] xy = (5., 6.) @@ -231,6 +240,7 @@ def f_outside(xy): degree=degree) for degree in [1, 2, 3, 4]) def test_higher_order_grad(self, degree=4): + self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad def f_jax(x): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 2329fe65e052..ec5945be6af8 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -247,7 +247,7 @@ def f(): jax.effects_barrier() @with_pure_and_io_callbacks - def test_callback_with_wrong_dtype_outputs(self, *, callback=io_callback_ordered): + def test_callback_with_wrong_dtype_outputs(self, *, callback): def _cb(): return np.array([1], np.float64)