diff --git a/celegans/simulation.py b/celegans/simulation.py index 2dedde7..b704d60 100644 --- a/celegans/simulation.py +++ b/celegans/simulation.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +from jax.scipy.integrate import trapezoid def _theta(t, s, params): @@ -123,9 +124,9 @@ def solve(t, u, X, ds, alpha): fx = Ut * tx[jnp.newaxis] + alpha * Un * nx[jnp.newaxis] fy = Ut * ty[jnp.newaxis] + alpha * Un * ny[jnp.newaxis] - Fx = jnp.trapz(fx, dx=ds) - Fy = jnp.trapz(fy, dx=ds) - Tau = jnp.trapz(x * fy - y * fx, dx=ds) + Fx = trapezoid(fx, dx=ds) + Fy = trapezoid(fy, dx=ds) + Tau = trapezoid(x * fy - y * fx, dx=ds) b = -jnp.array([Fx[0], Fy[0], Tau[0]]) A = jnp.array([Fx[1:], Fy[1:], Tau[1:]])