
JAX BDF solver tests failing / update [jax] versions (due to scipy.linalg.tril deprecation) #3959

Closed
2 tasks done
agriyakhetarpal opened this issue Apr 3, 2024 · 18 comments · Fixed by #4103
Labels
bug Something isn't working priority: high To be resolved as soon as possible release blocker Issues that need to be addressed before the creation of a release

Comments

@agriyakhetarpal
Member

agriyakhetarpal commented Apr 3, 2024

The JAX BDF solver tests are failing on all PRs (#3846, #3945, etc.) for Python 3.9 and later, because SciPy removed some linear algebra routines in v1.13.0. The Python 3.8 tests are passing because SciPy dropped support for Python 3.8 earlier, so those jobs still install an older SciPy release that includes these routines.
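For context, a minimal sketch of the incompatibility (not PyBaMM or JAX code): SciPy deprecated `scipy.linalg.tril` (along with `tri` and `triu`) and removed it in v1.13.0, so anything that still references `scipy.linalg.tril` fails on newer SciPy. `numpy.tril` computes the same lower triangle. The helper `lower_triangle` and the flag `SCIPY_HAS_TRIL` are hypothetical names used only for illustration.

```python
import numpy as np

try:
    import scipy.linalg
    # False on SciPy >= 1.13.0, where tri/tril/triu were removed from scipy.linalg.
    SCIPY_HAS_TRIL = hasattr(scipy.linalg, "tril")
except ImportError:  # SciPy is not installed at all
    SCIPY_HAS_TRIL = None


def lower_triangle(a, k=0):
    """Hypothetical helper: zero out every element above the k-th diagonal.

    Uses numpy.tril, which exists regardless of the installed SciPy version,
    instead of the removed scipy.linalg.tril.
    """
    return np.tril(np.asarray(a), k=k)


print("scipy.linalg.tril available:", SCIPY_HAS_TRIL)
print(lower_triangle(np.arange(1, 10).reshape(3, 3)))
```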

I'm guessing we need to bump the jax and jaxlib versions now or relax the pin in the requirements, because there have been quite a few releases since v0.4.20; the latest version available at the time of writing is v0.4.25.

Checklist

@agriyakhetarpal agriyakhetarpal changed the title JAX BDF solver tests failing / update [jax] versions (due to scipy.linalg.tril deprecation) JAX BDF solver tests failing / update [jax] versions (due to scipy.linalg.tril deprecation) Apr 3, 2024
@agriyakhetarpal agriyakhetarpal self-assigned this Apr 3, 2024
@agriyakhetarpal agriyakhetarpal added bug Something isn't working priority: high To be resolved as soon as possible labels Apr 3, 2024
@agriyakhetarpal
Member Author

It's probably not as simple as bumping the JAX version: there are a few other errors, involving JAX's JIT and spectral volumes, that I don't yet understand. I'm putting this aside for a bit and will return to it soon; others are welcome to proceed in the meantime if there is progress.

@agriyakhetarpal
Member Author

agriyakhetarpal commented Apr 3, 2024

Bumping to v0.4.24 fixes at least some of the tests; earlier versions still raise the SciPy error.

@kratman
Contributor

kratman commented Apr 3, 2024

It is worth bumping JAX as high as possible. We have people who are experienced with JAX who might be able to help. We are going to run into more compatibility issues as the code ages.

@agriyakhetarpal
Member Author

I agree with you. v0.4.26 is their latest release; should we drop the pin altogether? It might break on v0.5.x, so having >0.4, <0.5 bounds is another option.

@kratman
Contributor

kratman commented Apr 3, 2024

Pinning is fine, since it avoids unexpected changes. Realistically, we should have all major dependencies pinned. Something like Dependabot should handle the updates so that the failures all show up in one place.

@kratman
Contributor

kratman commented Apr 3, 2024

Do you need help with this one?

@valentinsulzer
Member

> we should have all major dependencies pinned

We shouldn't pin to exact versions, as that may cause compatibility issues for our users (for example, if they try to use PyBaMM alongside another package that happens to pin NumPy to a different version). We can specify ranges, but they should be as wide as possible.

@valentinsulzer
Member

JAX is an exception: we have to pin the exact version, since every release changes the API.
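To make the trade-off concrete, a small stdlib-only sketch of which versions an exact pin accepts compared with a `>=0.4,<0.5` style range. The helpers `parse_version`, `exact_pin` and `version_range` are hypothetical and deliberately simplified (plain `X.Y.Z` releases only, so `0.4.27` and `0.4.27.0` are treated as different); real installers evaluate PEP 440 specifiers, for example via `packaging.specifiers.SpecifierSet`.

```python
def parse_version(text):
    """Simplified parser: plain "X.Y.Z" releases only (no rc/dev/post tags)."""
    return tuple(int(part) for part in text.split("."))


def exact_pin(pinned):
    """Accept a single release, in the spirit of "jax==0.4.27"."""
    target = parse_version(pinned)
    return lambda candidate: parse_version(candidate) == target


def version_range(lower, upper):
    """Accept a half-open interval, in the spirit of ">=0.4,<0.5"."""
    low, high = parse_version(lower), parse_version(upper)
    # Tuples compare element-wise, so (0, 4) <= (0, 4, 20) < (0, 5).
    return lambda candidate: low <= parse_version(candidate) < high


pin = exact_pin("0.4.27")
window = version_range("0.4", "0.5")
for candidate in ["0.4.20", "0.4.24", "0.4.27", "0.5.0"]:
    print(candidate, "pin:", pin(candidate), "range:", window(candidate))
```

The exact pin accepts only 0.4.27, which guards against JAX's frequent API changes but can clash with other packages in a user's environment; the range fits more environments at the cost of admitting releases that were never tested.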

@agriyakhetarpal
Member Author

> Do you need help with this one?

I would appreciate that, as someone who hasn't used JAX much. I was able to get the tests to pass with newer versions of JAX (some of the remaining failures can probably be ignored, because the solves are likely not being cached properly on my machine). Some spatial method tests are still failing with IndexErrors, and my debugger doesn't help there.

> We can specify ranges but they should be as wide as possible

To add to this, we have been keeping the lower bounds in sync with the package versions available on conda-forge (an overly restrictive lower bound caused some trouble around the time of the PyBaMM 23.9 release). It might make sense to drop Python 3.8 soon, since its tests have only been passing because they still use the deprecated code?

@kratman
Contributor

kratman commented Apr 3, 2024

> It might make sense to drop Python 3.8 soon since it has been passing due to the use of deprecated code?

I was planning on putting up a PR for that this week; it seemed to align with the removal of ODEs and the removal of the JAX Windows restrictions. I will probably just go ahead and make that PR while helping with the JAX work. I should have a bit of time to take a look this afternoon. Just share the branch you are working on and I will see what I can do to help out.

@agriyakhetarpal
Member Author

I don't have a branch or anything concrete, I was debugging only locally. I'll add the link here once I get back to it

@valentinsulzer
Member

Yeah, let's follow NumPy's lead on which Python versions we support; they have dropped support for 3.8.

@kratman kratman mentioned this issue Apr 3, 2024
5 tasks
@kratman
Contributor

kratman commented Apr 3, 2024

A few related issues were solved with #3963, #3961, and #3962. I will take another stab at updating JAX in a few days.

@agriyakhetarpal agriyakhetarpal added the release blocker Issues that need to be addressed before the creation of a release label Apr 25, 2024
@brosaplanella
Member

I was checking this issue: hasn't this been solved by the PRs Eric referenced above? When I looked, the CI tests seemed to be passing.

@agriyakhetarpal
Member Author

Ah, those PRs solved one part of the issue. The other part is that we still need to unpin SciPy, which is currently pinned to <1.13.0, IIRC.
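Whether an environment hits the error depends on which SciPy (and jax/jaxlib) release was actually resolved, which can differ between Python versions and between local runs and CI. A minimal stdlib sketch for printing the installed versions; `installed_versions` is a hypothetical helper, and the names passed in are distribution names as published on PyPI.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(["jax", "jaxlib", "scipy"]).items():
    print(f"{name}: {ver or 'not installed'}")
```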

@brosaplanella
Member

So, if we unpin SciPy then tests fail, right? Is there any branch where this is done so I can see the errors?

@agriyakhetarpal
Member Author

Yes. I only tried locally last time, and I was about to open a PR to show you the logs, but I'm facing a strange error locally right now:

nox > python run-tests.py --unit
nox > Command python run-tests.py --unit failed with exit code -9
nox > Session unit failed.

and zsh kills my shell for some reason. I suspect this is caused by #4092, which we merged a while ago and which wasn't caught in CI on either architecture. It may also be related to the fact that I upgraded macOS a few hours ago.
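As a debugging aid: for a child process started through Python's `subprocess` module (nox runs session commands as subprocesses), a return code of -N means the child was terminated by signal N, so `exit code -9` indicates the test process received `SIGKILL`. `SIGKILL` cannot be caught; one common source is the operating system killing a process under memory pressure, though that is only one possibility here. A POSIX-only sketch that reproduces the return code:

```python
import signal
import subprocess
import sys

# Child process that sends SIGKILL to itself (POSIX only; Windows has no SIGKILL).
child_code = "import os, signal; os.kill(os.getpid(), signal.SIGKILL)"

result = subprocess.run([sys.executable, "-c", child_code])

# A negative return code -N means the child was terminated by signal N.
print("return code:", result.returncode)
if result.returncode < 0:
    print("terminated by:", signal.Signals(-result.returncode).name)
```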

@agriyakhetarpal
Member Author

agriyakhetarpal commented May 20, 2024

Edit: I see that you opened a PR just at the time I commented :)

brosaplanella added a commit that referenced this issue May 20, 2024
brosaplanella added a commit that referenced this issue May 21, 2024
* #3959 updated scipy and jax versions to fix deprecation error

* #3959 fix issue with vstack array dimensions

* style: pre-commit fixes

* #3959 use direct solver for interpolant

* Update pyproject.toml

Co-authored-by: Agriya Khetarpal <[email protected]>

* #3959 use jax and jaxlib 0.4.27

* #3959 revert to iterative solver for interpolator and relax test tolerances

* style: pre-commit fixes

* ruff

* #3959 Eric's comments

* #3959 reduce tolerances to fix macos-14 unit tests

* Update pyproject.toml

Co-authored-by: Agriya Khetarpal <[email protected]>

* #3959 relax some more tolerances

* #3959 reduce solve time in test to avoid overdischarge (aiming to fix macos failing test)

* #3959 extend solving time to reach end of discharge

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <[email protected]>
Co-authored-by: Eric G. Kratz <[email protected]>
js1tr3 pushed a commit to js1tr3/PyBaMM that referenced this issue Aug 12, 2024
Projects
None yet

4 participants