Deprecate jax.experimental.host_callback in favor of JAX external callbacks #20385

Open
gnecula opened this issue Mar 22, 2024 · 2 comments
gnecula commented Mar 22, 2024

We marked the host_callback APIs as deprecated on March 21, 2024 (JAX version 0.4.26). They will be removed in October 2024. Users should switch to the new JAX external callbacks.

Quick temporary migration

As of October 1st, 2024 (JAX version 0.4.34), the jax.experimental.host_callback APIs are implemented in terms of jax.experimental.io_callback. This is controlled by the configuration variable --jax_host_callback_legacy=False (or the environment variable JAX_HOST_CALLBACK_LEGACY=False).

For a very limited time, you can obtain the old behavior by setting the configuration variable to True.
Very soon this configuration flag will be removed, so it is best to take the time to do the migration as explained below.

Real migration

It is best to study the different flavors of JAX external callbacks to pick the right one for your use case.

Of the new callbacks, io_callback(ordered=True) comes closest to the behavior of the existing host_callback.

In general, you should replace calls to id_tap and call with io_callback, except when you need these calls to work under vmap, grad, jvp, scan, or cond, in which case you should use jax.debug.callback. Note that jax.debug.callback does not support returning values from the callback, so it can be used only in lieu of host_callback.id_print or host_callback.id_tap, or in lieu of host_callback.call when result_shape=None.
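As a rough illustration of that choice, the sketch below uses io_callback(ordered=True) where a value must come back from the host, and jax.debug.callback where the call sits under jax.grad. The names log, record, host_square, with_result, and loss are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

log = []  # hypothetical host-side sink for observed values


def record(x):
    # Side effect only: jax.debug.callback cannot return a value to JAX.
    log.append(float(np.asarray(x)))


def host_square(x):
    # Runs on the host; the result is handed back to the traced computation.
    return np.asarray(x) ** 2


@jax.jit
def with_result(x):
    # A value is needed from the host, so io_callback is the candidate here.
    return io_callback(host_square, jax.ShapeDtypeStruct((), x.dtype), x, ordered=True)


def loss(x):
    y = x * x
    # The call sits under jax.grad, so jax.debug.callback is used instead.
    jax.debug.callback(record, y)
    return y


value = with_result(jnp.float32(3.0))
grad = jax.grad(loss)(3.0)
jax.effects_barrier()  # wait for outstanding callbacks before reading `log`
```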

Known migration issues

  • the tap_with_device option for id_tap and the call_with_device option for call are not supported. You must change the callbacks to not need the device argument. If you use JAX_HOST_CALLBACK_LEGACY=False you will get an error.
  • the transforms argument to the callback called from id_tap is not supported. If you use JAX_HOST_CALLBACK_LEGACY=False the callback will be passed the empty tuple (no transforms).
  • the old host_callback APIs passed np.ndarray objects to the callback. The new JAX external callbacks pass jax.Array. This should be OK, except that it may lead to a deadlock if the code making the call is already running on CPU, because the callback will try to invoke JAX functions on the arguments and will find the device busy. The solution is to add input = np.array(input) at the start of your callback function.
  • If you attempt to use io_callback(ordered=True) with jax.grad, you will get an error that io_callback does not support JVP. Try to use jax.debug.callback instead.
  • If you attempt to use io_callback(ordered=True) with jax.pmap, you will get an error that ordered effects are not supported under jax.pmap. Try to use ordered=False.
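The np.array conversion mentioned in the list above might look like the following minimal sketch. The names host_sum and f are hypothetical; the point is only that the callback converts its argument to NumPy before doing any other work on it.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_sum(x):
    # Convert first, so the rest of the callback works on a plain np.ndarray
    # and does not invoke JAX operations on the incoming jax.Array.
    x = np.array(x)
    return x.sum(dtype=np.float32)


@jax.jit
def f(x):
    return io_callback(host_sum, jax.ShapeDtypeStruct((), jnp.float32), x, ordered=True)


out = f(jnp.arange(4, dtype=jnp.float32))
```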

Using io_callback in place of host_callback.call

For example,

from jax.experimental import host_callback
res = host_callback.call(fn, arg, result_shape=result_shape_dtypes)

should be replaced with

from jax.experimental import io_callback
res = io_callback(fn, result_shape_dtypes, arg)
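A self-contained version of the replacement could look like the sketch below. The host function host_cumsum and the concrete shapes are assumptions chosen for illustration; result_shape_dtypes is built with jax.ShapeDtypeStruct to describe what the host function returns.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_cumsum(x):
    # Hypothetical host-side function: plain NumPy, same shape and dtype as the input.
    x = np.array(x)
    return np.cumsum(x).astype(x.dtype)


arg = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
# Describes the shape and dtype that host_cumsum is expected to return.
result_shape_dtypes = jax.ShapeDtypeStruct(arg.shape, arg.dtype)

res = jax.jit(lambda a: io_callback(host_cumsum, result_shape_dtypes, a))(arg)
```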

Using io_callback in place of host_callback.id_tap

Similarly, id_tap can be replaced with an io_callback with result_shape_dtypes=None:

callback = lambda x, transforms: do_something(x)
res = host_callback.id_tap(callback, x_in)

should be replaced with

callback = lambda x: do_something(x)
io_callback(callback, None, x_in)
res = x_in  # Simulates the return value of `id_tap`

Note that we have removed the transforms callback argument (this is not supported by the new callbacks).
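If id_tap is used in many places, one option is a small wrapper that keeps the "returns its input" behavior. The helper name tap below is hypothetical and not part of JAX; this is only a sketch of the pattern described above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

seen = []  # hypothetical host-side record of tapped values


def do_something(x):
    # Must return None, because result_shape_dtypes is None below.
    seen.append(np.array(x).tolist())


def tap(callback, x):
    # Hypothetical helper: call the host function, then hand back the input,
    # which mimics the return value of the old id_tap.
    io_callback(callback, None, x, ordered=True)
    return x


@jax.jit
def f(x):
    y = tap(do_something, 2.0 * x)
    return y + 1.0


out = f(jnp.arange(3, dtype=jnp.float32))
jax.block_until_ready(out)
jax.effects_barrier()  # make sure the callback has run before inspecting `seen`
```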

If you use the result parameter with id_tap then you can replace:

results = id_tap(
    lambda arg, transform: done_callback(arg),
    arg,
    result=the_results,
)

with

io_callback(
    lambda arg: done_callback(arg),
    None,
    arg
)
results = the_results

Using jax.debug.print in place of host_callback.id_print

For id_print you should use instead jax.debug.print. E.g.,

id_print(x) can be replaced by jax.debug.print('{}', x).

If you use the name parameter, you can replace
id_print(x, name="my_x") with jax.debug.print('name: my_x\n{}', x).

If you use the output_stream parameter, you can replace:
id_print(x, output_stream=s) by jax.experimental.io_callback(lambda x: s.write(str(x)), None, x).
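The three id_print replacements can be tried together in a sketch like the one below. The StringIO object stands in for a hypothetical output_stream, and write_to_stream is a hypothetical helper; it is written as a def that returns None, on the assumption that a callback used with result_shape_dtypes=None should not return the character count from write.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

stream = io.StringIO()  # hypothetical stand-in for the old output_stream


def write_to_stream(x):
    # Write the value and return None (not the int returned by write()).
    stream.write(f"{np.asarray(x)}\n")


@jax.jit
def f(x):
    jax.debug.print("{}", x)              # in place of id_print(x)
    jax.debug.print("name: my_x\n{}", x)  # in place of id_print(x, name="my_x")
    io_callback(write_to_stream, None, x, ordered=True)  # in place of output_stream=...
    return x + 1


out = f(jnp.arange(3))
jax.block_until_ready(out)
jax.effects_barrier()
```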

Using jax.effects_barrier in place of host_callback.barrier_wait

Finally, host_callback.barrier_wait should be replaced with jax.effects_barrier().
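A minimal sketch of the barrier pattern, assuming a hypothetical step function whose callback appends to a host-side list. The barrier is placed before the host reads anything the callback was supposed to produce.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical host-side state filled in by the callback


@jax.jit
def step(x):
    jax.debug.callback(lambda v: seen.append(int(v)), x)
    return x + 1


y = step(jnp.int32(41))
# Previously host_callback.barrier_wait(); blocks until pending callbacks have run.
jax.effects_barrier()
```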

Callbacks and jax.vmap

Under vmap the new callbacks behave differently than the host_callback. The latter will make a single call with a vector value, while the new callbacks will behave like a loop, and will make separate calls for each element in the vmap. For example, the code

def host_fn(x):
  print(x)

def fn(x):
  res = 2 * x
  id_tap(host_fn, res)
  return res

jax.vmap(fn)(np.arange(3))

makes one call to host_fn with the vector [0, 2, 4], and if we replace id_tap(host_fn, res) with jax.debug.callback(host_fn, res) we will get 3 separate calls with 0, 2, and 4, respectively.
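The per-element behavior of the new callbacks can be observed with a sketch like the following, which records each call in a hypothetical calls list instead of printing. The order of the recorded calls is not asserted here.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical record of what each host call received


def host_fn(x):
    # int() only succeeds for a scalar, so this also checks that each call
    # receives a single element rather than the whole vector.
    calls.append(int(x))


def fn(x):
    res = 2 * x
    jax.debug.callback(host_fn, res)
    return res


out = jax.vmap(fn)(jnp.arange(3))
jax.effects_barrier()
```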

@gnecula gnecula added the enhancement New feature or request label Mar 22, 2024
@gnecula gnecula self-assigned this Mar 22, 2024
copybara-service bot pushed a commit that referenced this issue Mar 25, 2024
The jax.experimental.host_callback module is deprecated and will be removed.

See #20385.

Most of the changes here have to do with the fact that io_callback does not pass the `device` to the callback. Fortunately, it seems that this code uses the device argument only for logging. I removed all uses of `device`.

PiperOrigin-RevId: 618402363
copybara-service bot pushed a commit that referenced this issue Mar 29, 2024
…_outfeed_receiver.

The jax.experimental.host_callback module is deprecated and will be removed.

See #20385.

PiperOrigin-RevId: 620220346
copybara-service bot pushed a commit that referenced this issue Mar 29, 2024
The jax.experimental.host_callback module is deprecated and will be removed.

See #20385.

The other API entry points have been marked as deprecated already, but barrier_wait was missed.

PiperOrigin-RevId: 620222377
copybara-service bot pushed a commit that referenced this issue Mar 29, 2024
…_outfeed_receiver.

The jax.experimental.host_callback module is deprecated and will be removed.

See #20385.

PiperOrigin-RevId: 620237081
copybara-service bot pushed a commit that referenced this issue Mar 29, 2024
The jax.experimental.host_callback module is deprecated and will be removed.

See #20385.

The other API entry points have been marked as deprecated already, but barrier_wait was missed.

PiperOrigin-RevId: 620237286
gnecula added a commit to gnecula/jax that referenced this issue Apr 1, 2024
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.
gnecula added a commit to gnecula/jax that referenced this issue Apr 1, 2024
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
gnecula added a commit to gnecula/jax that referenced this issue Apr 2, 2024
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
gnecula added a commit to gnecula/jax that referenced this issue Apr 3, 2024
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
gnecula added a commit to gnecula/jax that referenced this issue Apr 4, 2024
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
gnecula added a commit to gnecula/jax that referenced this issue Apr 5, 2024
…ack.

The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue jax-ml#20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
copybara-service bot pushed a commit that referenced this issue Oct 5, 2024
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See #20385 for a discussion.

PiperOrigin-RevId: 682659677
copybara-service bot pushed a commit that referenced this issue Oct 6, 2024
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See #20385 for a discussion.

PiperOrigin-RevId: 682659677
copybara-service bot pushed a commit that referenced this issue Oct 6, 2024
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See #20385 for a discussion.

PiperOrigin-RevId: 682830525
copybara-service bot pushed a commit to google-research/t5x that referenced this issue Oct 6, 2024
…ental.io_callback.

The jax.experimental.host_callback module is deprecated and will be removed.

See jax-ml/jax#20385.

PiperOrigin-RevId: 682885092
copybara-service bot pushed a commit to google-research/google-research that referenced this issue Oct 6, 2024
…ental.io_callback.

The jax.experimental.host_callback module is deprecated and will be removed.

See jax-ml/jax#20385.

PiperOrigin-RevId: 682943395
copybara-service bot pushed a commit to google-research/t5x that referenced this issue Oct 7, 2024
…ental.io_callback.

The jax.experimental.host_callback module is deprecated and will be removed.

See jax-ml/jax#20385.

PiperOrigin-RevId: 682885092
copybara-service bot pushed a commit to google-research/t5x that referenced this issue Oct 7, 2024
…ental.io_callback.

The jax.experimental.host_callback module is deprecated and will be removed.

See jax-ml/jax#20385.

PiperOrigin-RevId: 683196778
copybara-service bot pushed a commit that referenced this issue Oct 8, 2024
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing them now would break internal pytype checking. We will remove them in the near future.

See #20385.

PiperOrigin-RevId: 682837548
copybara-service bot pushed a commit that referenced this issue Oct 8, 2024
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing them now would break internal pytype checking. We will remove them in the near future.

See #20385.

PiperOrigin-RevId: 683564340