From 09df1ac6898c4962f71509fef71f9bd2c80809c7 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 6 Oct 2024 01:48:09 -0700 Subject: [PATCH] Remove remaining implementations of jax.experimental.host_callback.call. The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future. See https://github.com/google/jax/issues/20385. PiperOrigin-RevId: 682837548 --- jax/experimental/host_callback.py | 68 +++---------------- tests/BUILD | 11 ---- tests/host_callback_test.py | 104 ------------------------------ 3 files changed, 8 insertions(+), 175 deletions(-) delete mode 100644 tests/host_callback_test.py diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index f6f51ba5796a..7d60f62e230f 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Backwards compatibility shim for the deprecated host_callback APIs. +"""Backwards compatibility shims for the deprecated host_callback APIs. .. warning:: The host_callback APIs are deprecated as of March 20, 2024. @@ -23,64 +23,12 @@ from __future__ import annotations -from collections.abc import Callable -import logging -import warnings +def call(*_, **__): + raise NotImplementedError( + "jax.experimental.host_callback has been deprecated since March 2024 and " + "is now no longer supported. " + "See https://github.com/jax-ml/jax/issues/20385" + ) -import jax -from jax.experimental import io_callback - -logger = logging.getLogger(__name__) - - -# We keep a shim for host_callback.call because it is still used in a few -# places in google. -def call(callback_func: Callable, - arg, - *, - result_shape=None, - call_with_device=False, - device_index=0, - callback_flavor=None): - """Make a call to the host, and expect a result. - - .. warning:: - The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - `new JAX external callbacks `_ - See https://github.com/jax-ml/jax/issues/20385. - """ - warnings.warn("""The host_callback APIs are deprecated as of March 20, 2024. - The functionality is subsumed by the - new JAX external callbacks (https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). - See https://github.com/jax-ml/jax/issues/20385 - """, DeprecationWarning, stacklevel=2) - if callback_flavor is not None: - raise NotImplementedError( - "host_callback.call is only supported with the IO_CALLBACK flavor.") - if call_with_device: - raise NotImplementedError( - "host_callback.call is only supported with the call_with_device=False.") - callback_device = jax.local_devices()[device_index] - sharding = jax.sharding.SingleDeviceSharding(callback_device) - return io_callback(callback_func, result_shape, arg, - sharding=sharding, - ordered=True) - -import typing -if typing.TYPE_CHECKING: - def id_tap(tap_func, - arg, - *, - result=None, - tap_with_device=False, - device_index=0, - callback_flavor=None, - **kwargs): - raise NotImplementedError( - "host_callback.id_tap is no longer supported. " - "See https://github.com/jax-ml/jax/issues/20385" - ) - -del typing +id_tap = call diff --git a/tests/BUILD b/tests/BUILD index e444c01d0ac5..087372eea5a6 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -330,7 +330,6 @@ jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], deps = [ - "//jax:experimental_host_callback", ], ) @@ -1142,16 +1141,6 @@ jax_multiplatform_test( deps = ["//jax:ode"], ) -jax_multiplatform_test( - name = "host_callback_test", - srcs = ["host_callback_test.py"], - main = "host_callback_test.py", - deps = [ - "//jax:experimental", - "//jax:ode", - ], -) - jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py deleted file mode 100644 index 42c4496643bf..000000000000 --- a/tests/host_callback_test.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from unittest import SkipTest - -from absl.testing import absltest - -import jax -from jax.experimental import host_callback as hcb -from jax._src import xla_bridge -from jax._src import test_util as jtu - -import numpy as np - -jax.config.parse_flags_with_absl() - - -class HostCallbackCallTest(jtu.JaxTestCase): - """Tests for hcb.call""" - - def setUp(self): - # skipping here skips teardown, so do this before super().setUp(). - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="backend and device argument")) - - def tearDown(self) -> None: - jax.effects_barrier() - super().tearDown() - - def test_call_simple(self): - - def f_outside(x): - return 2 * x - - def fun(x): - y = hcb.call(f_outside, x + 1, result_shape=x) - return 3 * (1 + y) - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - self.assertAllClose(3 * (1 + 2 * (arg + 1)), fun(arg)) - - - @jtu.sample_product( - dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_], - ) - def test_call_types(self, dtype=np.float64): - - def f_outside(x): - # Use x + x to ensure that the result type is the same - return x + x - - def fun(x): - return hcb.call(f_outside, x + x, result_shape=x) - - arg = np.arange(24, dtype=dtype).reshape((2, 3, 4)) - self.assertAllClose(arg + arg + arg + arg, fun(arg), check_dtypes=True) - - def test_call_types_bool(self, dtype=np.float64): - - def f_outside(x): - return np.invert(x) - - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - arg = self.rng().choice(a=[True, False], size=(2, 3, 4)) - self.assertAllClose(np.invert(arg), fun(arg)) - - def test_call_tuples(self): - - def f_outside(args): - x, y = args - return y, x # Swap the tuple - - def fun(x): - xy = hcb.call(f_outside, (x, x + 1), result_shape=(x, x)) - return 2 * xy[0] + 3 * xy[1] - - arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4)) - self.assertAllClose(2 * (arg + 1) + 3 * arg, fun(arg)) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader())