Crash in eval_jaxpr with 0.4.27 #21116
Comments
This is probably my bug; I'll look into it. I'm pretty sure it's happening because the tracer is being converted into an argument instead of being embedded into
This is because before
Actually, thinking more, tracers should be passed as arguments, so if you change So maybe the fix should be in diffrax?
How should that be done?
Now it's in Note that we have this assert that is disabled, but we want to enable it: https://github.com/google/jax/blob/f768cb74b94ab36587a8930be8afe8a34460ca6b/jax/_src/core.py#L208
Hmm, I think I only see `jaxpr.jaxpr.invars  # [Var(id=4790894848):int32[], Var(id=4790895168):int32[]]` but not any way of grabbing To approach this a different way: in <=0.4.26, it is the case that, for all functions `g`, the round trip `jaxpr = jax.make_jaxpr(g)(*args)` followed by `jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)` succeeds. I think that's quite a nice invariant, actually!
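For reference, a minimal sketch of that round trip for a function that closes over nothing. The function `g` and its arguments are illustrative assumptions; `jax.make_jaxpr` returns a `ClosedJaxpr`, and `jax.core.eval_jaxpr` takes the open jaxpr, its consts, and then the flat arguments, returning a flat list of outputs.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Hypothetical example: everything g needs arrives as an argument.
    return jnp.sin(x) * y + 2.0

args = (jnp.float32(0.5), jnp.float32(3.0))

closed = jax.make_jaxpr(g)(*args)
# eval_jaxpr(open_jaxpr, consts, *flat_args) -> list of flat outputs.
out = jax.core.eval_jaxpr(closed.jaxpr, closed.consts, *args)
print(out[0], g(*args))
```

The thread is about what happens when `g` additionally closes over a tracer, in which case the question is whether that tracer travels in `consts` or becomes an extra input.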
The first
I don't think so. At least right now, I don't see a way to grab the tracer out of the
What am I missing?
Sorry, I meant that the first `Var` is the tracer (which would have been in consts before). There is currently no indication that it is a tracer unless JAX exposes more information to you. That is what my question was about: would that be enough for you to fix diffrax? Before:
Now:
Right! But I don't need to know which Sorry, I think I'm realising where the confusion is coming from: I'm not calling (Note that Taking a step back from what's needed to solve this one particular bug in the short term: in general, abstract evaluation can return four things: the jaxpr, the output pytree/avals, the effects, and the closed-over values. Right now, these pieces are only partially mixed and matched across the public API. The state of affairs is:
Where an (Side note, credit to Gemini for kindly making this table for me ^^ )
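As a rough illustration (our own choice of entry points, not a reconstruction of the table above, which did not survive), this sketch shows where each of the four pieces can be read off with existing public calls, for a hypothetical toy function `g`: `jax.make_jaxpr(..., return_shape=True)` gives the jaxpr plus the output pytree of `ShapeDtypeStruct`s, the resulting `ClosedJaxpr` carries `.effects` and `.consts`, and `jax.eval_shape` returns only the output shapes.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Hypothetical toy function returning a small pytree.
    return {"sum": x + y, "prod": x * y}

x = jnp.ones((3,), jnp.float32)
y = jnp.ones((3,), jnp.float32)

# Pieces 1 and 2: the jaxpr together with the output pytree of ShapeDtypeStructs.
closed, out_shape = jax.make_jaxpr(g, return_shape=True)(x, y)
print(closed.jaxpr)
print(out_shape)

# Piece 3: effects recorded while tracing (none for this pure function).
print(closed.effects)

# Piece 4: closed-over values (none here, since g closes over nothing).
print(closed.consts)

# eval_shape returns only the output shapes, without the jaxpr.
print(jax.eval_shape(g, x, y))
```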
Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with. It should never have happened, and enabling this assert would enforce that: https://github.com/google/jax/blob/f768cb74b94ab36587a8930be8afe8a34460ca6b/jax/_src/core.py#L208
Eventually, we are going to merge `make_jaxpr` and `eval_shape` into
#21140 should roll it back. We are going to do another release tomorrow, so this should be fixed in 0.4.28.
That said, I think we're still figuring out the path forward here; this rollback is mostly to unbreak you.
Ah, interesting! I like the unification here. Is this something where we'll be able to grab the
I have no preference on whether tracers are placed in jaxprs at all. E.g. I would be equally happy to get them via something looking like The actual tracer object is required to perform closure conversion before crossing the boundaries of higher-order primitives, custom AD, etc.
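A minimal sketch of closure conversion using JAX's public `jax.closure_convert` utility. This only illustrates the general technique; it is an assumption that it resembles what diffrax needs, and `outer`/`inner` are hypothetical names. Inside `jit`, the closed-over `y` is a tracer; `closure_convert` traces `inner` at the example argument's aval and hoists such tracers out, returning a converted function that takes `inner`'s arguments followed by the hoisted values.

```python
import jax

def outer(x, y):
    def inner(z):
        # y is closed over; under jit it is a tracer, not a concrete value.
        return z * y

    # Hoist closed-over tracers out of `inner` so they become explicit inputs.
    converted, hoisted = jax.closure_convert(inner, x)
    # `converted` takes inner's own arguments followed by the hoisted values.
    return converted(x, *hoisted)

print(jax.jit(outer)(2.0, 3.0))
```

With the tracer passed explicitly, `converted` can be handed across a boundary such as a custom-derivative rule without implicitly capturing `y`, which is the use case `jax.closure_convert` is documented for.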
Thank you! I really appreciate this. If you're looking for a concrete suggestion from me on a path forward:
We just released jax 0.4.28, which includes the rollback.
Do you mean If you like, I could write a quick PR adding a flag to
Thank you! I appreciate it.
Rolling forward with #21734, which also fixes the reported error, so we should be good to go without any changes on your side.
Wonderful stuff! Thank you so much :)
Haha, no problem.
The semantics of `make_jaxpr` are preserved here, i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't. Since we can keep the existing behavior and still merge the implementations, this is a good cleanup! Fixes #21116
PiperOrigin-RevId: 641254235
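A small check of the behavior described in this commit message, assuming a JAX version that includes the roll-forward (#21734): when `make_jaxpr` is called inside `jit` on a function that closes over an outer tracer, that tracer is expected to end up in `closed.consts`, so the `make_jaxpr`/`eval_jaxpr` round trip only needs `g`'s own argument. The functions `f` and `g` are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(y):
    def g(x):
        # y is closed over; inside jit it is a tracer, not a concrete value.
        return x + y

    x0 = jnp.float32(1.0)
    closed = jax.make_jaxpr(g)(x0)
    # Runs at trace time: reports how many closed-over values landed in consts.
    print("num consts:", len(closed.consts))
    # Only g's own argument is passed; the tracer for y travels via consts.
    return jax.core.eval_jaxpr(closed.jaxpr, closed.consts, x0)[0]

print(f(2.0))  # 3.0
```

This is the invariant from the earlier comment (round-tripping through `make_jaxpr` and `eval_jaxpr` with only the original arguments), now holding again for closures over tracers.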
Description
produces:
System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.27