Adding host_callbacks.id_tap reorders JVP evaluation #3198
Hi @shoyer, it looks like this issue has been resolved in later versions of JAX. I executed the mentioned code on Colab with JAX version 0.4.23. Now both functions evaluate primals before tangents.

```python
import jax
from jax.experimental import host_callback

def f1(x):
    y = x ** 2
    return y

def f2(x):
    y = x ** 2
    y = host_callback.id_print(y)  # id_print returns its argument unchanged
    return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f1, (x,), (y,)))(0.0, 0.0))
print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f2, (x,), (y,)))(0.0, 0.0))
```
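The same comparison using `jax.debug.print` instead of `host_callback.id_print`:

```python
import jax

def f1(x):
    y = x ** 2
    return y

def f2(x):
    y = x ** 2
    # jax.debug.print returns None, so y must not be reassigned to its result.
    jax.debug.print("{}", y)
    return y

print('jvp without jax.debug.print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f1, (x,), (y,)))(0.0, 0.0))
print('\njvp with jax.debug.print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f2, (x,), (y,)))(0.0, 0.0))
```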
Please find the gist for reference. Thank you.
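Rather than comparing the two printed jaxprs by eye, the "same order, plus extra callbacks" property can be checked mechanically. This is a minimal sketch, assuming a JAX version in which `jax.debug.print` is staged out as a single primitive named `debug_callback`; `primitive_sequence` is a hypothetical helper, not a JAX API. It lists the primitives in each JVP jaxpr and checks that dropping the `debug_callback` entries from `f2`'s list leaves exactly `f1`'s list.

```python
import jax


def f1(x):
    return x ** 2


def f2(x):
    y = x ** 2
    jax.debug.print("{}", y)  # side effect only; returns None
    return y


def primitive_sequence(fn):
    """Primitive names in the jaxpr of jax.jvp(fn), in trace order.

    `primitive_sequence` is a hypothetical helper for illustration.
    """
    jvp_fn = lambda x, t: jax.jvp(fn, (x,), (t,))
    closed = jax.make_jaxpr(jvp_fn)(0.0, 0.0)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


plain = primitive_sequence(f1)
printed = primitive_sequence(f2)
# Assumption: jax.debug.print appears as a `debug_callback` primitive.
without_callbacks = [name for name in printed if name != "debug_callback"]

print("f1:", plain)
print("f2:", printed)
print("same order apart from callbacks:", without_callbacks == plain)
```

If the two lists ever differed after filtering, the debugging call would be changing what gets traced, which is the behavior this issue reports for `id_print`.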
`id_tap` has been removed from JAX, so this issue should be moot.
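For code that relied on `id_tap`, one possible replacement is `jax.debug.callback`, the general host-callback API that `jax.debug.print` is built on. A minimal sketch, assuming a plain "observe this value" use case: unlike `id_tap`, it returns nothing instead of passing its argument through, so the traced value is used directly afterwards. The `tap` function, the `seen` list and `f3` are hypothetical names for illustration.

```python
import jax

seen = []  # hypothetical host-side log of values observed by the callback


def tap(value):
    # Runs on the host with a concrete array; record it as a Python float.
    seen.append(float(value))


def f3(x):
    y = x ** 2
    jax.debug.callback(tap, y)  # returns None; y itself is unchanged
    return y


primal_out, tangent_out = jax.jvp(f3, (3.0,), (1.0,))
jax.effects_barrier()  # wait for any pending callbacks to finish

print(primal_out, tangent_out, seen)
```

Under `jax.jvp`, the callback is invoked on the primal value only, so `seen` records the primal `x ** 2` rather than the tangent.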
(Forked from #3127)
Consider the following example:
The function `f2` is exactly the same as `f1`, except with the addition of `id_print`. Naively, I would expect these functions to be evaluated in exactly the same order, except with some extra calls to `id_tap` injected. But as we can see from the jaxprs, that isn't what happens: without `id_print`, primals are evaluated before tangents, but with `id_print`, tangents are evaluated first!
This is a perfectly valid way to calculate JVPs, of course, but it's a little worrisome for a debugging utility to change how computation happens. It's all the more worrisome because JVPs are implemented with tracers, which I would not expect to change the order of function evaluation. I can imagine this resulting in some very frustrating debugging sessions, e.g., if code crashes only during the tangent calculation.
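The ordering question can also be checked directly on a jaxpr. A rough sketch, assuming that the position of the equation producing each output is a fair proxy for when that output is computed: `jax.jvp` returns `(primal_out, tangent_out)`, so the traced jaxpr's first output is the primal and its second is the tangent. `output_positions` is a hypothetical helper, and `jax.debug.print` stands in for `id_print`, which no longer exists.

```python
import jax


def f1(x):
    return x ** 2


def f2(x):
    y = x ** 2
    jax.debug.print("{}", y)  # stands in for the removed host_callback.id_print
    return y


def output_positions(fn):
    """Return (primal_index, tangent_index): the positions of the equations
    that produce the primal and tangent outputs in the jaxpr of jax.jvp(fn).

    `output_positions` is a hypothetical helper for illustration.
    """
    closed = jax.make_jaxpr(lambda x, t: jax.jvp(fn, (x,), (t,)))(0.0, 0.0)
    jaxpr = closed.jaxpr
    primal_var, tangent_var = jaxpr.outvars

    def producer(var):
        for i, eqn in enumerate(jaxpr.eqns):
            if var in eqn.outvars:
                return i
        raise ValueError(f"{var} is not produced by any equation")

    return producer(primal_var), producer(tangent_var)


for name, fn in [("f1", f1), ("f2", f2)]:
    p, t = output_positions(fn)
    print(f"{name}: primal at eqn {p}, tangent at eqn {t}, primal first: {p < t}")
```

A check like this could sit in a test suite so that a debugging utility that reorders primal and tangent computation is caught automatically.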