From 75bddf515891e50d2b0d312c782199eb90218db2 Mon Sep 17 00:00:00 2001 From: Albert Alonso Date: Thu, 14 Mar 2024 23:09:25 +0100 Subject: [PATCH 1/2] fix: replace deprecated jax.numpy.trapz Apparently in the 0.4.16 JAX release, there were several deprecations following NEP52 (https://numpy.org/neps/nep-0052-python-api-cleanup.html) One of those was jax.numpy.trapz. Instead we are meant to use jax.scipy.integrate.trapozoid. The C. Elegans code used trapz so it is not running on the current version. --- celegans/simulation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/celegans/simulation.py b/celegans/simulation.py index 2dedde7..b704d60 100644 --- a/celegans/simulation.py +++ b/celegans/simulation.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +from jax.scipy.integrate import trapezoid def _theta(t, s, params): @@ -123,9 +124,9 @@ def solve(t, u, X, ds, alpha): fx = Ut * tx[jnp.newaxis] + alpha * Un * nx[jnp.newaxis] fy = Ut * ty[jnp.newaxis] + alpha * Un * ny[jnp.newaxis] - Fx = jnp.trapz(fx, dx=ds) - Fy = jnp.trapz(fy, dx=ds) - Tau = jnp.trapz(x * fy - y * fx, dx=ds) + Fx = trapezoid(fx, dx=ds) + Fy = trapezoid(fy, dx=ds) + Tau = trapezoid(x * fy - y * fx, dx=ds) b = -jnp.array([Fx[0], Fy[0], Tau[0]]) A = jnp.array([Fx[1:], Fy[1:], Tau[1:]]) From 16a6c3859aee847818443d50b26032db02e265c6 Mon Sep 17 00:00:00 2001 From: Albert Alonso Date: Fri, 15 Mar 2024 00:38:07 +0100 Subject: [PATCH 2/2] build: fix packages versions and scikit-version Packages requiere newer Python versions than 3.8. But some packages are old and use deprecated features (scikit-video), so I added a quick fix. Ideally I fix it at some point. Some versions on requirements are also more clearly specified, as there are other deprecations. --- .github/workflows/integration_test.yml | 4 ++-- examples/detect.py | 6 ++++++ examples/track.py | 6 ++++++ requirements.txt | 9 +++++---- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index e0d9905..77ddf1b 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -14,10 +14,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v2 with: - python-version: '3.8' + python-version: '3.10' architecture: 'x64' - name: apt-get diff --git a/examples/detect.py b/examples/detect.py index 0be8574..2c85c92 100644 --- a/examples/detect.py +++ b/examples/detect.py @@ -2,6 +2,12 @@ import deeptangle as dt import matplotlib.pyplot as plt from skimage.exposure import equalize_adapthist +import numpy + +# scikit-video uses deprecated numpy.float, numpy.int +# hacky fix: https://github.com/scikit-video/scikit-video/issues/154 +numpy.float = numpy.float64 +numpy.int = numpy.int_ import skvideo.io diff --git a/examples/track.py b/examples/track.py index 5173212..3ae7f88 100644 --- a/examples/track.py +++ b/examples/track.py @@ -9,6 +9,12 @@ import matplotlib.pyplot as plt import numpy as np from skimage.exposure import equalize_adapthist + +# scikit-video uses deprecated numpy.float, numpy.int +# hacky fix: https://github.com/scikit-video/scikit-video/issues/154 +import numpy +numpy.float = numpy.float64 +numpy.int = numpy.int_ import skvideo.io import deeptangle as dt diff --git a/requirements.txt b/requirements.txt index 51dae77..b2bd536 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,11 +3,12 @@ scikit-image scikit-video optax chex -jax +jax>=0.4.16 +jaxlib>=0.4.20 dm-pix scikit-learn -numpy==1.21.6 -numba==0.55 +numpy>=1.21.6 +numba>=0.55 matplotlib -https://github.com/alonfnt/dm-haiku/archive/refs/heads/avg_pool_perf.zip +dm-haiku trackpy