Cryptic XLA error when using JAX 0.4.26 #21396
Comments
Can you please try with (a) a fresh virtualenv and (b) using …
Hi @hawkinsp, thank you for the quick response. I followed your instructions but encountered the same error. To give more context, the error occurs during batched operations with a model written in Equinox. Unfortunately, I have not been able to reproduce it with simpler examples, which leads me to believe the issue is not on Equinox's side. @patrick-kidger, have you encountered similar issues by any chance?
I've not seen this one before, I'm afraid!
Can you share an HLO dump? That might be enough for us to reproduce. Run with …
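The exact command suggested above is cut off in this copy of the thread. As a hedged sketch (not necessarily what the maintainer asked for), one common way to produce an HLO dump is XLA's `--xla_dump_to` flag passed through the `XLA_FLAGS` environment variable, optionally combined with JAX's ahead-of-time `lower()`/`compile()` API to inspect the module text directly. The function `f` and the temporary dump directory are hypothetical placeholders.

```python
import os
import tempfile

# Hypothetical dump location; any writable directory works.
# XLA reads XLA_FLAGS when the backend initializes, so set it before importing jax.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for the failing batched computation.
    return jnp.sin(x) * 2.0


x = jnp.ones((4, 3))

# Lower to StableHLO (before XLA optimizations), then compile to optimized HLO.
lowered = jax.jit(f).lower(x)
stablehlo_text = lowered.as_text()
compiled = lowered.compile()
optimized_hlo_text = compiled.as_text()

# Compilation writes module_*.txt files into dump_dir; zip that directory
# and attach it to the issue.
y = compiled(x)
print(sorted(os.listdir(dump_dir))[:5])
```

In practice only the `XLA_FLAGS` line is needed when rerunning the real program; the `lower()`/`compile()` calls are just a quick way to confirm what JAX hands to XLA.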
Hmm. I can't reproduce from the HLO dump. I think we'll need a Python-level reproduction.
Actually, never mind, I can reproduce from the HLO.
Hi @hawkinsp, I am checking in to see if there are any updates on this issue. Any information would be helpful. Thanks!
I filed an internal bug for our XLA compiler folks (b/342589917), and I'm waiting for one of them to take a look. I have no additional information other than: yes, I can reproduce the problem from the HLO dump.
A fix for this was merged (openxla/xla@e7bd8ad). The fix should be in today's jaxlib nightly. Please try it out and let me know if the problem is fixed.
Hi @hawkinsp,
For what it's worth, I've encountered the same error in jaxlib 0.4.28 when using …
The nightly then gave a nice new traceback that allowed me to fix the issue.
Description
Dear JAX Team,
I am migrating my code from jax 0.4.25 to jax 0.4.26 because I would like to use it with the latest NGC container, and I am running into XLA errors that I cannot quite understand. Unfortunately, I could not reproduce the error with a small piece of code. However, I have observed that the error occurs only when the batch size is greater than one. Here is the error message:
Could you please assist in diagnosing and resolving these issues? Any guidance would be greatly appreciated.
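Since the error appears only for batch sizes above one, a small harness that compiles and runs the same function under `jax.jit(jax.vmap(...))` at several batch sizes can help narrow down which configurations fail. This is a minimal sketch under stated assumptions: `model` is a hypothetical single dense layer standing in for the real Equinox model, and `check_batch_sizes` is an illustrative helper, not part of JAX or Equinox.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # Hypothetical stand-in for the real Equinox model: one dense layer.
    w, b = params
    return jnp.tanh(x @ w + b)


def check_batch_sizes(fn, params, example, batch_sizes=(1, 2, 4)):
    """Compile and run `fn` under jit + vmap for each batch size.

    Returns a dict mapping each batch size to None on success, or to the
    exception type and message on failure.
    """
    results = {}
    # Map over the leading axis of the inputs only; params are shared.
    batched = jax.jit(jax.vmap(fn, in_axes=(None, 0)))
    for n in batch_sizes:
        xs = jnp.broadcast_to(example, (n,) + example.shape)
        try:
            # Each new batch size triggers a fresh trace and XLA compilation.
            batched(params, xs).block_until_ready()
            results[n] = None
        except Exception as e:  # XLA compile/runtime errors surface here
            results[n] = f"{type(e).__name__}: {e}"
    return results


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (3, 5))
b = jnp.zeros((5,))
results = check_batch_sizes(model, (w, b), jnp.ones((3,)))
print(results)
```

Swapping the real model and a real input in for `model` and `example` would show whether the failure starts exactly at batch size 2, which is useful detail for a bug report.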
System info (python version, jaxlib version, accelerator, etc.)
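A minimal way to collect these details, assuming the installed JAX provides `jax.print_environment_info` (it reports the jax, jaxlib, NumPy, and Python versions plus device information):

```python
import jax

# Collect version and device details as a string so it can be pasted into the issue.
info = jax.print_environment_info(return_string=True)
print(info)

# Also list the devices JAX sees (e.g. GPUs inside the NGC container).
print(jax.devices())
```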