[Bug]: Update JAX imports #3683

Closed
cringeyburger opened this issue Jan 3, 2024 · 3 comments · Fixed by #3684
Labels
bug (Something isn't working), difficulty: easy (A good issue for someone new. Can be done in a few hours)

Comments

@cringeyburger
Contributor

PyBaMM Version

Develop

Python Version

3.11.5

Describe the bug

While working on solutions for #3617, I discovered that several JAX imports are deprecated. For some reason, these deprecations stop pytest from running the tests correctly: pytest only shows the deprecation warnings.

I tried fixing a few of the imports in the source code, but it seemed better to open an issue first.

Steps to Reproduce

  1. From the file test_base_submodel.py, run any test using pytest (I ran test_parameter_info_error).
  2. The command I used to run pytest is as follows:
     /home/yukinatsu/miniconda3/bin/conda run -n pybamm --no-capture-output pytest /home/yukinatsu/PyBaMM/tests/unit/test_models/test_submodels/test_base_submodel.py::TestBaseSubModel::test_parameter_info_error
  3. You will encounter 3 deprecation warnings in total (each warning only appears after the previous one has been fixed).
  4. The warnings are as follows:

Two occurrences of: DeprecationWarning: Accessing jax.config via the jax.config submodule is deprecated.

One occurrence of: DeprecationWarning: jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.
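The fact that each warning only appears after the previous one is fixed is consistent with DeprecationWarning being escalated to an error during the test run, for example through pytest's `-W error::DeprecationWarning` option or a `filterwarnings = error` setting. The issue does not say which configuration PyBaMM uses, so treat this as an assumption. The minimal, stdlib-only sketch below uses hypothetical stand-in functions that emit the two warning messages quoted above. Under the "error" filter, the first warning aborts the run and the second is never reached. Recording the warnings instead shows both.

```python
import warnings


def deprecated_config_access():
    # Hypothetical stand-in for a deprecated import that goes through
    # the jax.config submodule; it only emits the quoted warning text.
    warnings.warn(
        "Accessing jax.config via the jax.config submodule is deprecated.",
        DeprecationWarning,
        stacklevel=2,
    )


def deprecated_linear_util_access():
    # Hypothetical stand-in for a use of jax.linear_util.transformation.
    warnings.warn(
        "jax.linear_util.transformation is deprecated. "
        "Use jax.extend.linear_util.transformation instead.",
        DeprecationWarning,
        stacklevel=2,
    )


def run_with_warnings_as_errors():
    """Return the message of the first DeprecationWarning, raised as an error."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", DeprecationWarning)
        try:
            deprecated_config_access()
            deprecated_linear_util_access()  # not reached if the line above raises
        except DeprecationWarning as exc:
            return str(exc)
    return None


def collect_all_warnings():
    """Record every warning instead of raising, so all of them are visible."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        deprecated_config_access()
        deprecated_linear_util_access()
    return [str(w.message) for w in caught]
```

Recording with `catch_warnings(record=True)` is handy when you want to see every deprecation at once rather than fixing them one by one.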

Affected files:

  1. jax_bdf_solver.py
  2. evaluate_python.py

After I changed the imports as needed, there were no further errors and pytest worked as intended. I would be happy to take this one.
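The two migrations named in the warnings can be sketched roughly as follows. This is a hedged sketch, not the actual change from #3684. The helper names `enable_x64_if_available` and `import_linear_util` are hypothetical, and enabling `jax_enable_x64` is only an illustrative setting. It presumably replaces the submodule-style `from jax.config import config` with the `jax.config` object on the top-level package. It also prefers `jax.extend.linear_util`, as the warning text recommends, and falls back to `jax.linear_util` for older JAX releases. Both helpers return early when JAX is not installed.

```python
import importlib
import importlib.util
from types import ModuleType
from typing import Optional


def enable_x64_if_available() -> bool:
    """Enable 64-bit mode via the top-level jax.config object.

    Returns False without doing anything if JAX is not installed.
    """
    if importlib.util.find_spec("jax") is None:
        return False
    import jax

    # Deprecated style (accesses the jax.config submodule):
    #     from jax.config import config
    #     config.update("jax_enable_x64", True)
    # Replacement: use the config attribute of the top-level jax package.
    jax.config.update("jax_enable_x64", True)
    return True


def import_linear_util() -> Optional[ModuleType]:
    """Return linear_util from jax.extend if possible, as the warning suggests.

    Falls back to the older jax.linear_util location, and returns None when
    neither module can be imported (e.g. JAX is not installed).
    """
    for module_name in ("jax.extend.linear_util", "jax.linear_util"):
        try:
            return importlib.import_module(module_name)
        except ImportError:
            continue
    return None
```

The explicit loop keeps the fallback visible. Once older JAX releases no longer need to be supported, a plain `from jax.extend import linear_util as lu` is simpler.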

Relevant log output

No response

cringeyburger added the bug (Something isn't working) label on Jan 3, 2024
@agriyakhetarpal
Member

Thanks for opening an issue about this, @cringeyburger; I'm happy to assign you to it. It looks like I missed some of the deprecation warnings you mention when I opened #3644.

agriyakhetarpal added the difficulty: easy (A good issue for someone new. Can be done in a few hours) label on Jan 3, 2024
@kratman
Contributor

kratman commented Jan 3, 2024

@cringeyburger Was this with the latest PyBaMM? A PR addressing the jax_bdf_solver.py warnings (#3671) was merged a few hours ago. I'm curious whether that change was incomplete; if so, feel free to finish it up.

@cringeyburger
Contributor Author

Yes, I noticed the recent PR merged by @prady0t, but it seems that a few warnings were left out: it solved the jax.extend one, but not the jax.config ones. I will be opening a PR.
