
autograd -> JAX #118

Open
ryancoe opened this issue Jul 1, 2022 · 5 comments · May be fixed by #319
Labels
enhancement New feature or request upstream Related to a dependency of our package

Comments

@ryancoe
Collaborator

ryancoe commented Jul 1, 2022

Ever since originally adopting autograd, we've been concerned that most of the development energy from autograd has moved to JAX. In addition to continued development, JAX also has more complete functionality (e.g., for fft).
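As a small illustration of the FFT point, the sketch below differentiates a real-valued function through `jnp.fft.fft` with `jax.grad` and checks the result against the closed form given by Parseval's theorem. The function name `spectral_energy` is hypothetical (it is not taken from WecOptTool), and the example does not compare JAX against autograd's FFT support.

```python
import jax
import jax.numpy as jnp


def spectral_energy(x):
    """Sum of squared DFT magnitudes of a real signal x (hypothetical helper)."""
    X = jnp.fft.fft(x)
    # real**2 + imag**2 avoids the non-differentiable point of abs() at zero
    return jnp.sum(X.real ** 2 + X.imag ** 2)


x = jnp.arange(8.0)
grad_energy = jax.grad(spectral_energy)(x)

# Parseval (unnormalized DFT): spectral_energy(x) == N * sum(x**2),
# so the exact gradient is 2 * N * x.
expected = 2 * x.size * x
print(grad_energy)
print(jnp.allclose(grad_energy, expected, atol=1e-3))
```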

We did not use JAX initially because support for MS Windows is not great: users must either compile JAX themselves, use a third-party binary, or use Windows Subsystem for Linux (WSL) (https://github.com/google/jax#installation).

Given that more direct JAX support for MS Windows does not seem imminent, there are two major hurdles preventing us from transitioning:

  • Additional challenge for MS Windows users: we need to see how much of a burden it is to use Windows Subsystem for Linux
  • Changes to our source: we initially thought this would amount to changing `import autograd.numpy as np` to `import jax.numpy as np`, but apparently there is more to it than that
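A minimal sketch of two differences that go beyond swapping the import, assuming code that relies on NumPy-style in-place assignment and on double precision (both are assumptions here, not confirmed details of this codebase): `jax.numpy` arrays are immutable, so item assignment raises a `TypeError` and the functional `.at[...].set(...)` form is used instead; and JAX computes in 32-bit floats unless `jax_enable_x64` is turned on, whereas NumPy (and therefore autograd) defaults to float64. `objective` is a hypothetical function for illustration.

```python
import jax
import jax.numpy as jnp

# Assumption: double precision is wanted, matching NumPy's float64 default.
# Set this at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.array([1.0, 2.0, 3.0])

# NumPy-style item assignment is rejected: JAX arrays are immutable.
try:
    x[0] = 0.0
    mutation_rejected = False
except TypeError:
    mutation_rejected = True


def objective(x):
    """Hypothetical objective: zero the first entry, then sum of squares."""
    y = x.at[0].set(0.0)  # returns a new array; x itself is unchanged
    return jnp.sum(y ** 2)


g = jax.grad(objective)(x)
print(mutation_rejected)  # True
print(g, g.dtype)
```

Because `.at[0].set(0.0)` overwrites the first entry, its gradient is zero and the remaining entries get `2 * x`.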
@ryancoe ryancoe added the upstream Related to a dependency of our package label Jul 1, 2022
@cmichelenstrofer cmichelenstrofer added the enhancement New feature or request label Oct 17, 2022
@michaelcdevin
Collaborator

@ryancoe @cmichelenstrofer it appears a Windows-compatible pip install is now available for JAX as of v0.4.13. That's one of the major implementation hurdles out of the way!

@cmichelenstrofer
Member

We should try it at some point. But it does require changing the source a bit, so it won't be a small task.

@ryancoe
Collaborator Author

ryancoe commented Oct 10, 2023

@michaelcdevin - can you quickly check how well the JAX Windows install works as a first step?

@michaelcdevin
Collaborator

`pip install jax` works without a hitch on Windows. I tested some of the basic `jax.numpy` and `jax.grad` operations and all seems to work as expected.
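A smoke test along these lines is one way to repeat that check on a fresh Windows install. The specific operations below are an assumption for illustration, not the ones actually run in this comment.

```python
import jax
import jax.numpy as jnp

# Report the installed version and the backend JAX found
# (typically CPU for a plain `pip install jax`).
print("jax version:", jax.__version__)
print("devices:", jax.devices())

# A basic jax.numpy operation
x = jnp.linspace(0.0, 1.0, 5)  # [0, 0.25, 0.5, 0.75, 1]
total = jnp.sum(x)

# jax.grad of a scalar function: d/dt sin(t) = cos(t)
dsin = jax.grad(jnp.sin)
t = 0.3

# jax.jit compiles the function on first call
sq_norm = jax.jit(lambda v: jnp.dot(v, v))(x)

print(total, dsin(t), jnp.cos(t), sq_norm)
```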

@jorgeypcb jorgeypcb linked a pull request Mar 12, 2024 that will close this issue
9 tasks
@michaelcdevin
Collaborator

NumPy v2.0 was released four days ago with various breaking changes. Since autograd is no longer maintained, it will not be updated for these changes, so it is incompatible with current and future releases of NumPy.

It looks like JAX was proactive about maintaining compatibility with NumPy v2.0. I added a NumPy version limitation in 902c17b as a stopgap so WecOptTool doesn't break, but switching from autograd to JAX is now a higher priority so we don't fall behind on NumPy versions.

@cmichelenstrofer @ryancoe


3 participants