Skip to content

Commit

Permalink
fix: replace deprecated jax.numpy.trapz
Browse files Browse the repository at this point in the history
Apparently in the 0.4.16 JAX release, there were several deprecations
following NEP52
(https://numpy.org/neps/nep-0052-python-api-cleanup.html)
One of those was jax.numpy.trapz. Instead we are meant to use
jax.scipy.integrate.trapozoid.

The C. Elegans code used trapz so it is not running on the current
version.
  • Loading branch information
alonfnt committed Mar 14, 2024
1 parent 80c1c06 commit 75bddf5
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions celegans/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid


def _theta(t, s, params):
Expand Down Expand Up @@ -123,9 +124,9 @@ def solve(t, u, X, ds, alpha):
fx = Ut * tx[jnp.newaxis] + alpha * Un * nx[jnp.newaxis]
fy = Ut * ty[jnp.newaxis] + alpha * Un * ny[jnp.newaxis]

Fx = jnp.trapz(fx, dx=ds)
Fy = jnp.trapz(fy, dx=ds)
Tau = jnp.trapz(x * fy - y * fx, dx=ds)
Fx = trapezoid(fx, dx=ds)
Fy = trapezoid(fy, dx=ds)
Tau = trapezoid(x * fy - y * fx, dx=ds)

b = -jnp.array([Fx[0], Fy[0], Tau[0]])
A = jnp.array([Fx[1:], Fy[1:], Tau[1:]])
Expand Down

0 comments on commit 75bddf5

Please sign in to comment.