Jax 0.4.27: ValueError: safe_map() argument 2 is shorter than argument 1 #716
Thanks for the report! I've reproduced this as an upstream bug in JAX: jax-ml/jax#21116
Thank you for the quick response!
Finally, a last test broke for me:
Ah, looks like a bit of other work might be required too. These other failures at least don't look too scary: I think they're just small perturbations from where things were before, which we can and should adjust for in Equinox.
With jax 0.4.28, the
in the following tests:
Okay, I think things should be fixed with #719. I'll aim to do a new release shortly.
Cool! I will test that as soon as the release is out.
Failures are tracked upstream at patrick-kidger/equinox#716
All good, thanks!
Since jax 0.4.27, several tests fail with: