Deprecate jax.experimental.host_callback in favor of JAX external callbacks #20385
Labels: enhancement (New feature or request)
Comments
copybara-service bot pushed a commit that referenced this issue (Mar 25, 2024, pushed twice): The jax.experimental.host_callback module is deprecated and will be removed. See #20385. Most of the changes here have to do with the fact that io_callback does not pass the `device` to the callback. Fortunately, it seems that this code uses the device argument only for logging. I removed all uses of `device`. PiperOrigin-RevId: 618402363
copybara-service bot pushed commits that referenced this issue (Mar 29, 2024):
- …_outfeed_receiver. The jax.experimental.host_callback module is deprecated and will be removed. See #20385. PiperOrigin-RevId: 620220346 (pushed twice), later PiperOrigin-RevId: 620237081
- The jax.experimental.host_callback module is deprecated and will be removed. See #20385. The other API entry points have been marked as deprecated already, but barrier_wait was missed. PiperOrigin-RevId: 620222377 (pushed twice), later PiperOrigin-RevId: 620237286
gnecula added a commit to gnecula/jax that referenced this issue (Apr 1, 2024): …ack. The host_callbacks APIs are deprecated and will be removed. In order to help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`) that when set to `False` will use `io_callback` (and `pure_callback` and `jax.debug.callback`) to implement the host_callback APIs. See issue jax-ml#20385 for more details.
gnecula added the commit again to gnecula/jax (three times on Apr 1, twice on Apr 2, once on Apr 3, four times on Apr 4, and four times on Apr 5, 2024), with the same message extended by: We change the tests to accommodate slightly different results when using the new callbacks. The tests that use `tap_with_device` and `call_with_device` are disabled when using the new callbacks.
copybara-service bot pushed commits that referenced this issue (Oct 5 and Oct 6, 2024): These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks. See #20385 for a discussion. PiperOrigin-RevId: 682659677 (pushed five times on Oct 5 and three times on Oct 6), then PiperOrigin-RevId: 682830525 (Oct 6).
copybara-service bot pushed commits to other repositories that referenced this issue, with the message: …ental.io_callback. The jax.experimental.host_callback module is deprecated and will be removed. See jax-ml/jax#20385.
- google-research/t5x: PiperOrigin-RevId: 682885092 (twice on Oct 6, once on Oct 7, 2024), then PiperOrigin-RevId: 683196778 (Oct 7, 2024)
- google-research/google-research: PiperOrigin-RevId: 682943395 (Oct 6, 2024)
copybara-service bot pushed commits that referenced this issue (Oct 8, 2024):
- See #20385. PiperOrigin-RevId: 682837548 (pushed four times)
- The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now breaks internal pytype checking. We will remove those in the near future. See #20385. PiperOrigin-RevId: 682837548 (pushed twice), then PiperOrigin-RevId: 683564340
This was referenced Oct 31, 2024
We marked the `host_callback` APIs as deprecated on March 21, 2024 (JAX version 0.4.26). They will be removed in October 2024. Users should instead use the new JAX external callbacks.
Quick temporary migration
As of October 1, 2024 (JAX version 0.4.34), if you use the `jax.experimental.host_callback` APIs, they will be implemented in terms of `jax.experimental.io_callback`. This is controlled by the configuration variable `--jax_host_callback_legacy=False` (or the environment variable `JAX_HOST_CALLBACK_LEGACY=False`). For a very limited time, you can obtain the old behavior by setting the configuration variable to `True`. This configuration flag will be removed very soon, so it is best to take the time to do the migration as explained below.
Real migration
It is best to study the different flavors of JAX external callbacks to pick the right one for your use case.
In general, `io_callback(ordered=True)` behaves most similarly to the existing `host_callback`. You should replace calls to `id_tap` and `call` with `io_callback`, except when you need these calls to work under `vmap`, `grad`, `jvp`, `scan`, or `cond`, in which case you should use `jax.debug.callback`. Note that `jax.debug.callback` does not support returning values from the callback, so it can be used only in lieu of `host_callback.id_print` or `host_callback.id_tap`, or in lieu of `host_callback.call` when `result_shape=None`.
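As a minimal sketch of the `jax.debug.callback` case (assuming a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`), the snippet below keeps a tap-style callback working under `jax.grad`. The names `log_value` and `logged` are hypothetical, chosen only for illustration.

```python
import numpy as np
import jax

logged = []  # hypothetical host-side log, for illustration only

def log_value(x):
    # Runs on the host and returns nothing, like an id_tap callback
    # (but without the old `transforms` argument).
    logged.append(np.array(x))

def f(x):
    # Previously: hcb.id_tap(log_value, x). An io_callback(ordered=True)
    # here would raise an error under jax.grad, so use jax.debug.callback.
    jax.debug.callback(log_value, x)
    return x * x

g = jax.grad(f)(3.0)
jax.effects_barrier()  # wait until pending callbacks have run
print(g, logged)
```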
Known migration issues
- The `tap_with_device` option for `id_tap` and the `call_with_device` option for `call` are not supported. You must change the callbacks to not need the `device` argument. If you use `JAX_HOST_CALLBACK_LEGACY=False`, you will get an error.
- The `transforms` argument to the callback called from `id_tap` is not supported. If you use `JAX_HOST_CALLBACK_LEGACY=False`, the callback will be passed the empty tuple (no transforms).
- The `host_callback` APIs passed `np.ndarray` objects to the callback. The new JAX external callbacks pass `jax.Array`. This should be OK, except that it may lead to a deadlock if the code making the call is already running on CPU, because the callback will try to invoke JAX functions on the arguments and will find the device busy. The solution is to add `input = np.array(input)` at the start of your callback function.
- If you use `io_callback(ordered=True)` with `jax.grad`, you will get an error that `io_callback` does not support JVP. Try to use `jax.debug.callback`.
- If you use `io_callback(ordered=True)` with `jax.pmap`, you will get an error that ordered effects are not supported under `jax.pmap`. Try to use `ordered=False`.
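A minimal sketch of the `np.array` workaround from the list above, assuming CPU execution under `jax.jit`; `host_cumsum` is a hypothetical host function. The callback converts its `jax.Array` argument to a NumPy array before doing any work, so the host code uses only NumPy.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def host_cumsum(x):
    # Convert the incoming jax.Array to np.ndarray first, so the host code
    # below does not dispatch JAX operations while the device may be busy.
    x = np.array(x)
    # np.cumsum may widen the integer dtype; cast back to the declared dtype.
    return np.cumsum(x).astype(x.dtype)

@jax.jit
def f(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(host_cumsum, out_type, x, ordered=True)

result = f(jnp.arange(4, dtype=jnp.int32))
print(result)
```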
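Using `io_callback` in place of `host_callback.call`
For example, `host_callback.call(host_fn, x, result_shape=out_shape)` should be replaced with `io_callback(host_fn, out_shape, x, ordered=True)`, where `out_shape` describes the shape and dtype of the value returned by `host_fn` (for example, a `jax.ShapeDtypeStruct`).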
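A minimal sketch of the `call` replacement, assuming `jax.ShapeDtypeStruct` is used to describe the result; `host_fn` is a hypothetical host-side function and the old call is shown only as a comment, since `host_callback` has been removed.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def host_fn(x):
    # Host-side computation that host_callback.call used to invoke.
    x = np.array(x)
    return np.sin(x)  # float32 in, float32 out

@jax.jit
def f(x):
    # Previously: hcb.call(host_fn, x, result_shape=x)
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(host_fn, out_shape, x, ordered=True)

out = f(jnp.arange(3.0))
print(out)
```

As with `result_shape`, the declared shape and dtype must match what `host_fn` actually returns.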
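Using `io_callback` in place of `host_callback.id_tap`
Similarly, `id_tap` can be replaced with an `io_callback` with `result_shape_dtypes=None`: a call `id_tap(host_fn, x)` becomes `io_callback(host_fn, None, x, ordered=True)`, and since `id_tap` returned its argument, you use `x` directly wherever you previously used the result of `id_tap`.
Note that we have removed the `transforms` callback argument (this is not supported by the new callbacks), so `host_fn` takes only the tapped value.
If you use the `result` parameter with `id_tap`, then you can replace `id_tap(host_fn, x, result=r)` with the call `io_callback(host_fn, None, x, ordered=True)` followed by using `r` directly.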
Using `jax.debug.print` in place of `host_callback.id_print`
For `id_print` you should instead use `jax.debug.print`. For example, `id_print(x)` can be replaced by `jax.debug.print('{}', x)`.
If you use the `name` parameter, you can replace `id_print(x, name="my_x")` with `jax.debug.print('name: my_x\n{}', x)`.
If you use the `output_stream` parameter, you can replace `id_print(x, output_stream=s)` by `jax.experimental.io_callback(lambda x: print(x, end="", file=s), None, x)`. The callback must return `None` to match the `None` result shape; `s.write` would return the number of characters written.
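A minimal sketch of the three `id_print` replacements inside one jitted function, assuming an `io.StringIO` stands in for the old `output_stream`; `write_to_stream` is a hypothetical helper that deliberately returns `None`.

```python
import io
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

s = io.StringIO()  # stands in for the old `output_stream` argument

def write_to_stream(x):
    # Returns None (s.write's character count is discarded), matching the
    # `None` result shape passed to io_callback below.
    s.write(str(x))

@jax.jit
def f(x):
    jax.debug.print("{}", x)              # was: id_print(x)
    jax.debug.print("name: my_x\n{}", x)  # was: id_print(x, name="my_x")
    io_callback(write_to_stream, None, x, ordered=True)  # was: id_print(x, output_stream=s)
    return x

f(jnp.arange(3))
jax.effects_barrier()
```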
Using `jax.effects_barrier` in place of `host_callback.barrier_wait`
Finally, `host_callback.barrier_wait` should be replaced with `jax.effects_barrier()`.
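A minimal sketch of waiting for callbacks, assuming unordered `jax.debug.callback` calls that may still be pending when the jitted calls return; `seen` is a hypothetical host-side list. Because unordered callbacks are not guaranteed to run in call order, the result is only meaningful as a set.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical host-side record of callback values

@jax.jit
def f(x):
    jax.debug.callback(lambda v: seen.append(int(v)), x)
    return x + 1

for i in range(3):
    f(jnp.int32(i))

# Was: host_callback.barrier_wait(). Blocks until pending callbacks have run.
jax.effects_barrier()
print(sorted(seen))
```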
Callbacks and `jax.vmap`
Under `vmap`, the new callbacks behave differently from `host_callback`. The latter makes a single call with a vector value, while the new callbacks behave like a loop and make a separate call for each element in the `vmap`. For example, if a function that calls `id_tap(host_fn, res)` is mapped with `vmap` so that `res` takes the values `[0, 2, 4]`, `host_callback` makes one call to `host_fn` with the vector `[0, 2, 4]`; if we replace `id_tap(host_fn, res)` with `jax.debug.callback(host_fn, res)`, we get 3 separate calls with `0`, `2`, and `4`, respectively.
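A minimal sketch of the per-element behavior under `vmap`, assuming eager (non-jitted) `jax.vmap`; `host_fn` and `calls` are hypothetical names. Each batch element triggers its own callback, so `calls` receives three scalars rather than one vector.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical record of what the host function receives

def host_fn(v):
    # Called once per batch element with a scalar value.
    calls.append(int(v))

def f(x):
    res = x * 2
    jax.debug.callback(host_fn, res)  # was: id_tap(host_fn, res)
    return res

out = jax.vmap(f)(jnp.arange(3))
jax.effects_barrier()
print(sorted(calls))
```

If your host code relied on receiving the whole batch in one call, you may need to restructure it to accumulate the per-element values itself.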