Skip to content

Commit

Permalink
Remove remaining implementations of jax.experimental.host_callback APIs.
Browse files Browse the repository at this point in the history
See #20385.

PiperOrigin-RevId: 682837548
  • Loading branch information
gnecula authored and Google-ML-Automation committed Oct 8, 2024
1 parent 6a958b9 commit dcd40e6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 171 deletions.
63 changes: 7 additions & 56 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backwards compatibility shim for the deprecated host_callback APIs.
"""Backwards compatibility shims for the deprecated host_callback APIs.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
Expand All @@ -22,65 +22,16 @@
"""

from __future__ import annotations

from collections.abc import Callable
import logging
import warnings

import jax
from jax.experimental import io_callback


logger = logging.getLogger(__name__)


# We keep a shim for host_callback.call because it is still used in a few
# places in google.
def call(callback_func: Callable,
arg,
*,
result_shape=None,
call_with_device=False,
device_index=0,
callback_flavor=None):
"""Make a call to the host, and expect a result.
.. warning::
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
See https://github.com/jax-ml/jax/issues/20385.
"""
warnings.warn("""The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
new JAX external callbacks (https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
See https://github.com/jax-ml/jax/issues/20385
""", DeprecationWarning, stacklevel=2)
if callback_flavor is not None:
raise NotImplementedError(
"host_callback.call is only supported with the IO_CALLBACK flavor.")
if call_with_device:
raise NotImplementedError(
"host_callback.call is only supported with the call_with_device=False.")
callback_device = jax.local_devices()[device_index]
sharding = jax.sharding.SingleDeviceSharding(callback_device)
return io_callback(callback_func, result_shape, arg,
sharding=sharding,
ordered=True)

import typing

if typing.TYPE_CHECKING:
def id_tap(tap_func,
arg,
*,
result=None,
tap_with_device=False,
device_index=0,
callback_flavor=None,
**kwargs):
# We keep a couple of shims until a later CL to avoid breaking type checking
# in a few places in google.
def id_tap(*_, **__):
raise NotImplementedError(
"host_callback.id_tap is no longer supported. "
"jax.experimental.host_callback is no longer supported. "
"See https://github.com/jax-ml/jax/issues/20385"
)
call = id_tap

del typing
11 changes: 0 additions & 11 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ jax_multiplatform_test(
name = "infeed_test",
srcs = ["infeed_test.py"],
deps = [
"//jax:experimental_host_callback",
],
)

Expand Down Expand Up @@ -1142,16 +1141,6 @@ jax_multiplatform_test(
deps = ["//jax:ode"],
)

jax_multiplatform_test(
name = "host_callback_test",
srcs = ["host_callback_test.py"],
main = "host_callback_test.py",
deps = [
"//jax:experimental",
"//jax:ode",
],
)

jax_multiplatform_test(
name = "key_reuse_test",
srcs = ["key_reuse_test.py"],
Expand Down
104 changes: 0 additions & 104 deletions tests/host_callback_test.py

This file was deleted.

0 comments on commit dcd40e6

Please sign in to comment.