Skip to content

Commit

Permalink
Fix Jax links (#4504)
Browse files Browse the repository at this point in the history
* Fix lychee

* Fix other jax links
  • Loading branch information
kratman authored Oct 11, 2024
1 parent 974b10a commit 3bf3ea8
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ coverage.xml
htmlcov/

# virtual environment
.venv
env/
venv/
venv3.5/
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/CITATIONS.bib
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ @article{Hindmarsh2005
@misc{jax2018,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and Skye Wanderman-Milne},
title = {{JAX: composable transformations of Python+NumPy programs}},
url = {http://github.com/google/jax},
url = {http://github.com/jax-ml/jax},
version = {0.2.5},
year = {2018},
}
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
MIN_FACTOR = 0.2
MAX_FACTOR = 10

# https://github.com/google/jax/issues/4572#issuecomment-709809897
# https://github.com/jax-ml/jax/issues/4572#issuecomment-709809897
def some_hash_function(x):
return hash(str(x))

Expand Down Expand Up @@ -711,7 +711,7 @@ def block_fun(i, j, Ai, Aj):
return onp.block(blocks)

# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor
# edits), has been modified from the JAX library at https://github.com/google/jax.
# edits), has been modified from the JAX library at https://github.com/jax-ml/jax.
# The main difference is the addition of support for semi-explicit dae index 1
# problems via the addition of a mass matrix.
# This is under an Apache license, a short form of which is given here:
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class JaxSolver(pybamm.BaseSolver):
extra_options : dict, optional
Any options to pass to the solver.
Please consult `JAX documentation
<https://github.com/google/jax/blob/master/jax/experimental/ode.py>`_
<https://github.com/jax-ml/jax/blob/master/jax/experimental/ode.py>`_
for details.
"""

Expand Down Expand Up @@ -263,8 +263,8 @@ async def solve_model_async(inputs_v):
# sparse matrix support in JAX resulting in high memory usage, or a bug
# in the BDF solver.
#
# This issue on guthub appears related:
# https://github.com/google/jax/discussions/13930
# This issue on GitHub appears related:
# https://github.com/jax-ml/jax/discussions/13930
#
# # Split input list based on the number of available xla devices
# device_count = jax.local_device_count()
Expand Down

0 comments on commit 3bf3ea8

Please sign in to comment.