diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index f6f51ba5796a..9a684752ca8f 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Backwards compatibility shim for the deprecated host_callback APIs. +"""Backwards compatibility shims for the deprecated host_callback APIs. .. warning:: The host_callback APIs are deprecated as of March 20, 2024. @@ -22,65 +22,16 @@ """ from __future__ import annotations - -from collections.abc import Callable -import logging -import warnings - -import jax -from jax.experimental import io_callback - - -logger = logging.getLogger(__name__) - - -# We keep a shim for host_callback.call because it is still used in a few -# places in google. -def call(callback_func: Callable, - arg, - *, - result_shape=None, - call_with_device=False, - device_index=0, - callback_flavor=None): - """Make a call to the host, and expect a result. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - """ - warnings.warn("""The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - new JAX external callbacks (https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). - See https://github.com/jax-ml/jax/issues/20385 - """, DeprecationWarning, stacklevel=2) - if callback_flavor is not None: - raise NotImplementedError( - "host_callback.call is only supported with the IO_CALLBACK flavor.") - if call_with_device: - raise NotImplementedError( - "host_callback.call is only supported with the call_with_device=False.") - callback_device = jax.local_devices()[device_index] - sharding = jax.sharding.SingleDeviceSharding(callback_device) - return io_callback(callback_func, result_shape, arg, - sharding=sharding, - ordered=True) - import typing + if typing.TYPE_CHECKING: - def id_tap(tap_func, - arg, - *, - result=None, - tap_with_device=False, - device_index=0, - callback_flavor=None, - **kwargs): + # We keep a couple of shims until a later CL to avoid breaking type checking + # in a few places in google. + def id_tap(*_, **__): raise NotImplementedError( - "host_callback.id_tap is no longer supported. " + "jax.experimental.host_callback is no longer supported. " "See https://github.com/jax-ml/jax/issues/20385" ) + call = id_tap del typing diff --git a/tests/BUILD b/tests/BUILD index 188b5ae814d7..df517eeed72f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -330,7 +330,6 @@ jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], deps = [ - "//jax:experimental_host_callback", ], ) @@ -1142,16 +1141,6 @@ jax_multiplatform_test( deps = ["//jax:ode"], ) -jax_multiplatform_test( - name = "host_callback_test", - srcs = ["host_callback_test.py"], - main = "host_callback_test.py", - deps = [ - "//jax:experimental", - "//jax:ode", - ], -) - jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py deleted file mode 100644 index 42c4496643bf..000000000000 --- a/tests/host_callback_test.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from unittest import SkipTest - -from absl.testing import absltest - -import jax -from jax.experimental import host_callback as hcb -from jax._src import xla_bridge -from jax._src import test_util as jtu - -import numpy as np - -jax.config.parse_flags_with_absl() - - -class HostCallbackCallTest(jtu.JaxTestCase): - """Tests for hcb.call""" - - def setUp(self): - # skipping here skips teardown, so do this before super().setUp(). - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="backend and device argument")) - - def tearDown(self) -> None: - jax.effects_barrier() - super().tearDown() - - def test_call_simple(self): - - def f_outside(x): - return 2 * x - - def fun(x): - y = hcb.call(f_outside, x + 1, result_shape=x) - return 3 * (1 + y) - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - self.assertAllClose(3 * (1 + 2 * (arg + 1)), fun(arg)) - - - @jtu.sample_product( - dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_], - ) - def test_call_types(self, dtype=np.float64): - - def f_outside(x): - # Use x + x to ensure that the result type is the same - return x + x - - def fun(x): - return hcb.call(f_outside, x + x, result_shape=x) - - arg = np.arange(24, dtype=dtype).reshape((2, 3, 4)) - self.assertAllClose(arg + arg + arg + arg, fun(arg), check_dtypes=True) - - def test_call_types_bool(self, dtype=np.float64): - - def f_outside(x): - return np.invert(x) - - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - arg = self.rng().choice(a=[True, False], size=(2, 3, 4)) - self.assertAllClose(np.invert(arg), fun(arg)) - - def test_call_tuples(self): - - def f_outside(args): - x, y = args - return y, x # Swap the tuple - - def fun(x): - xy = hcb.call(f_outside, (x, x + 1), result_shape=(x, x)) - return 2 * xy[0] + 3 * xy[1] - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - self.assertAllClose(2 * (arg + 1) + 3 * arg, fun(arg)) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader())