From 789b7533ce177400702b868b41573a382e9c8bf2 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sun, 3 Mar 2024 18:10:41 -0500 Subject: [PATCH 1/4] Use lineax for tridiagonal solve --- interpax/_spline.py | 30 ++++++++++++------------------ requirements-dev.txt | 1 + requirements.txt | 2 +- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 1b4a8a7..e949323 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -7,6 +7,7 @@ import equinox as eqx import jax import jax.numpy as jnp +import lineax as lx import numpy as np from jax import jit @@ -1228,22 +1229,14 @@ def approx_df( dxi = jnp.where(dx == 0, 0, 1 / dx) df = dxi * df - A = jnp.diag( - jnp.concatenate( - ( - np.array([1.0]), - 2 * (dx.flatten()[:-1] + dx.flatten()[1:]), - np.array([1.0]), - ) - ) - ) - upper_diag1 = jnp.diag( - jnp.concatenate((np.array([1.0]), dx.flatten()[:-1])), k=1 - ) - lower_diag1 = jnp.diag( - jnp.concatenate((dx.flatten()[1:], np.array([1.0]))), k=-1 - ) - A += upper_diag1 + lower_diag1 + one = jnp.array([1.0]) + dxflat = dx.flatten() + diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one]) + upper_diag = jnp.concatenate([one, dxflat[:-1]]) + lower_diag = jnp.concatenate([dxflat[1:], one]) + + A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) + b = jnp.concatenate( [ 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), @@ -1260,8 +1253,9 @@ def approx_df( ) ba = jnp.moveaxis(b, axis, 0) br = ba.reshape((b.shape[axis], -1)) - fx = jnp.linalg.solve(A, br).reshape(ba.shape) - fx = jnp.moveaxis(fx, 0, axis) + solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value + fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T + fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) return fx elif method in ["cardinal", "catmull-rom"]: diff --git a/requirements-dev.txt b/requirements-dev.txt index daebeb8..99098ad 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ -r ./requirements.txt jax[cpu] >= 0.3.2, <= 0.5.0 +scipy >= 1.5.0, < 2.0 # building the docs sphinx > 3.0.0 diff --git a/requirements.txt b/requirements.txt index 5842a7e..9fa60ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ equinox jax >= 0.3.2, <= 0.5.0 +lineax numpy >= 1.20.0, < 2.0 -scipy >= 1.5.0, < 2.0 From c2c273bf52bf715bdc8d2ece8cb2bb484e604f26 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sun, 3 Mar 2024 18:49:12 -0500 Subject: [PATCH 2/4] Move stuff around and reduce testing time --- interpax/__init__.py | 2 +- interpax/_coefs.py | 161 ++++++++++++++++++ interpax/_fd_derivs.py | 206 ++++++++++++++++++++++ interpax/_spline.py | 349 +------------------------------------- tests/test_interpolate.py | 110 ++---------- 5 files changed, 382 insertions(+), 446 deletions(-) create mode 100644 interpax/_coefs.py create mode 100644 interpax/_fd_derivs.py diff --git a/interpax/__init__.py b/interpax/__init__.py index a2dc73a..539bdb5 100644 --- a/interpax/__init__.py +++ b/interpax/__init__.py @@ -1,12 +1,12 @@ """interpax: interpolation and function approximation with JAX.""" from . import _version +from ._fd_derivs import approx_df from ._fourier import fft_interp1d, fft_interp2d from ._spline import ( Interpolator1D, Interpolator2D, Interpolator3D, - approx_df, interp1d, interp2d, interp3d, diff --git a/interpax/_coefs.py b/interpax/_coefs.py new file mode 100644 index 0000000..2ba6938 --- /dev/null +++ b/interpax/_coefs.py @@ -0,0 +1,161 @@ +"""Matrices/tensors for getting spline coefficients from derivatives.""" + +import numpy as np + +# fmt: off +A_TRICUBIC = np.array([ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 9,-9,-9, 9, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 4, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 6,-6, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2,-2,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 6,-6, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2,-1,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 4,-4,-4, 4, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9,-9, 9, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 4, 2, 2, 1, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0,-2,-2,-1,-1, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0,-2,-1,-2,-1, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4,-4, 4, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], # noqa: E501 + [-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 9,-9, 0, 0,-9, 9, 0, 0, 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 0, 0, 6,-6, 0, 0,-3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9, 0, 0,-9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0], # noqa: E501 + [ 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0], # noqa: E501 + [-27,27,27,-27,27,-27,-27,27,-18,-9,18, 9,18, 9,-18,-9,-18,18,-9, 9,18,-18, 9,-9,-18,18,18,-18,-9, 9, 9, # noqa: E501 + -9,-12,-6,-6,-3,12, 6, 6, 3,-12,-6,12, 6,-6,-3, 6, 3,-12,12,-6, 6,-6, 6,-3, 3,-8,-4,-4,-2,-4,-2,-2,-1], # noqa: E501 + [18,-18,-18,18,-18,18,18,-18, 9, 9,-9,-9,-9,-9, 9, 9,12,-12, 6,-6,-12,12,-6, 6,12,-12,-12,12, 6,-6,-6, # noqa: E501 + 6, 6, 6, 3, 3,-6,-6,-3,-3, 6, 6,-6,-6, 3, 3,-3,-3, 8,-8, 4,-4, 4,-4, 2,-2, 4, 4, 2, 2, 2, 2, 1, 1], # noqa: E501 + [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0], # noqa: E501 + [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6, 9,-9, 9,-9,-9, 9,-9, 9,12,-12,-12,12, 6,-6,-6, # noqa: E501 + 6, 6, 3, 6, 3,-6,-3,-6,-3, 8, 4,-8,-4, 4, 2,-4,-2, 6,-6, 6,-6, 3,-3, 3,-3, 4, 2, 4, 2, 2, 1, 2, 1], # noqa: E501 + [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-6, 6,-6, 6, 6,-6, 6,-6,-8, 8, 8,-8,-4, 4, 4,-4, # noqa: E501 + -3,-3,-3,-3, 3, 3, 3, 3,-4,-4, 4, 4,-2,-2, 2, 2,-4, 4,-4, 4,-2, 2,-2, 2,-2,-2,-2,-2,-1,-1,-1,-1], # noqa: E501 + [ 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 0, 0, 6,-6, 0, 0,-4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 4,-4, 0, 0,-4, 4, 0, 0, 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4, 0, 0,-4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0], # noqa: E501 + [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0], # noqa: E501 + [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6,12,-12, 6,-6,-12,12,-6, 6, 9,-9,-9, 9, 9,-9,-9, # noqa: E501 + 9, 8, 4, 4, 2,-8,-4,-4,-2, 6, 3,-6,-3, 6, 3,-6,-3, 6,-6, 3,-3, 6,-6, 3,-3, 4, 2, 2, 1, 4, 2, 2, 1], # noqa: E501 + [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-8, 8,-4, 4, 8,-8, 4,-4,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 + -4,-4,-2,-2, 4, 4, 2, 2,-3,-3, 3, 3,-3,-3, 3, 3,-4, 4,-2, 2,-4, 4,-2, 2,-2,-2,-1,-1,-2,-2,-1,-1], # noqa: E501 + [ 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], # noqa: E501 + [-12,12,12,-12,12,-12,-12,12,-8,-4, 8, 4, 8, 4,-8,-4,-6, 6,-6, 6, 6,-6, 6,-6,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 + -4,-2,-4,-2, 4, 2, 4, 2,-4,-2, 4, 2,-4,-2, 4, 2,-3, 3,-3, 3,-3, 3,-3, 3,-2,-1,-2,-1,-2,-1,-2,-1], # noqa: E501 + [ 8,-8,-8, 8,-8, 8, 8,-8, 4, 4,-4,-4,-4,-4, 4, 4, 4,-4, 4,-4,-4, 4,-4, 4, 4,-4,-4, 4, 4,-4,-4, 4, # noqa: E501 + 2, 2, 2, 2,-2,-2,-2,-2, 2, 2,-2,-2, 2, 2,-2,-2, 2,-2, 2,-2, 2,-2, 2,-2, 1, 1, 1, 1, 1, 1, 1, 1] # noqa: E501 +]) + +A_BICUBIC = np.array([ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [-3, 3, 0, 0, -2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [2, -2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, -2, -1, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 1, 1, 0, 0 ], + [-3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, -3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0 ], + [9, -9, -9, 9, 6, 3, -6, -3, 6, -6, 3, -3, 4, 2, 2, 1 ], + [-6, 6, 6, -6, -3, -3, 3, 3, -4, 4, -2, 2, -2, -2, -1, -1 ], + [2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0 ], + [-6, 6, 6, -6, -4, -2, 4, 2, -3, 3, -3, 3, -2, -1, -2, -1 ], + [4, -4, -4, 4, 2, 2, -2, -2, 2, -2, 2, -2, 1, 1, 1, 1] +]) + +A_CUBIC = np.array([ + [1, 0, 0, 0], + [0, 0, 1, 0], + [-3, 3, -2, -1], + [2, -2, 1, 1], +]) diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py new file mode 100644 index 0000000..eb7e4fe --- /dev/null +++ b/interpax/_fd_derivs.py @@ -0,0 +1,206 @@ +from functools import partial + +import jax +import jax.numpy as jnp +import lineax as lx +from jax import jit + + +@partial(jit, static_argnames=("method", "axis")) +def approx_df( + x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs +): + """Approximates first derivatives using cubic spline interpolation. + + Parameters + ---------- + x : ndarray, shape(Nx,) + coordinates of known function values ("knots") + f : ndarray + Known function values. Should have length ``Nx`` along axis=axis + method : str + method of approximation + + - ``'cubic'``: C1 cubic splines (aka local splines) + - ``'cubic2'``: C2 cubic splines (aka natural splines) + - ``'catmull-rom'``: C1 cubic centripetal "tension" splines + - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass + keyword parameter ``c`` in float[0,1] to specify tension + - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the + data, and will not introduce new extrema in the interpolated points + - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at + both endpoints + + axis : int + Axis along which f is varying. + + Returns + ------- + df : ndarray, shape(f.shape) + First derivative of f with respect to x. + + """ + if method == "cubic": + out = _cubic1(x, f, axis, **kwargs) + elif method == "cubic2": + out = _cubic2(x, f, axis) + elif method == "cardinal": + out = _cardinal(x, f, axis, **kwargs) + elif method == "catmull-rom": + out = _cardinal(x, f, axis, **kwargs) + elif method == "monotonic": + out = _monotonic(x, f, axis, False, **kwargs) + elif method == "monotonic-0": + out = _monotonic(x, f, axis, True, **kwargs) + elif method in ("nearest", "linear"): + out = jnp.zeros_like(f) + else: + raise ValueError(f"got unknown method {method}") + return out + + +def _cubic1(x, f, axis): + dx = jnp.diff(x) + df = jnp.diff(f, axis=axis) + dxi = jnp.where(dx == 0, 0, 1 / dx) + if df.ndim > dxi.ndim: + dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) + dxi = jnp.moveaxis(dxi, 0, axis) + df = dxi * df + fx = jnp.concatenate( + [ + jnp.take(df, jnp.array([0]), axis, mode="wrap"), + 1 + / 2 + * ( + jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") + + jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") + ), + jnp.take(df, jnp.array([-1]), axis, mode="wrap"), + ], + axis=axis, + ) + return fx + + +def _cubic2(x, f, axis): + dx = jnp.diff(x) + df = jnp.diff(f, axis=axis) + if df.ndim > dx.ndim: + dx = jnp.expand_dims(dx, tuple(range(1, df.ndim))) + dx = jnp.moveaxis(dx, 0, axis) + dxi = jnp.where(dx == 0, 0, 1 / dx) + df = dxi * df + + one = jnp.array([1.0]) + dxflat = dx.flatten() + diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one]) + upper_diag = jnp.concatenate([one, dxflat[:-1]]) + lower_diag = jnp.concatenate([dxflat[1:], one]) + + A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) + + b = jnp.concatenate( + [ + 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), + 3 + * ( + jnp.take(dx, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") + * jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") + + jnp.take(dx, jnp.arange(1, df.shape[axis]), axis, mode="wrap") + * jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") + ), + 2 * jnp.take(df, jnp.array([-1]), axis, mode="wrap"), + ], + axis=axis, + ) + ba = jnp.moveaxis(b, axis, 0) + br = ba.reshape((b.shape[axis], -1)) + solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value + fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T + fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) + return fx + + +def _cardinal(x, f, axis, c=0): + dx = x[2:] - x[:-2] + df = jnp.take(f, jnp.arange(2, f.shape[axis]), axis, mode="wrap") - jnp.take( + f, jnp.arange(0, f.shape[axis] - 2), axis, mode="wrap" + ) + dxi = jnp.where(dx == 0, 0, 1 / dx) + if df.ndim > dxi.ndim: + dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) + dxi = jnp.moveaxis(dxi, 0, axis) + df = dxi * df + fx0 = jnp.take(f, jnp.array([1]), axis, mode="wrap") - jnp.take( + f, jnp.array([0]), axis, mode="wrap" + ) + fx0 *= jnp.where(x[0] == x[1], 0, 1 / (x[1] - x[0])) + fx1 = jnp.take(f, jnp.array([-1]), axis, mode="wrap") - jnp.take( + f, jnp.array([-2]), axis, mode="wrap" + ) + fx1 *= jnp.where(x[-1] == x[-2], 0, 1 / (x[-1] - x[-2])) + + fx = (1 - c) * jnp.concatenate([fx0, df, fx1], axis=axis) + return fx + + +def _monotonic(x, f, axis, zero_slope): + f = jnp.moveaxis(f, axis, 0) + fshp = f.shape + if f.ndim == 1: + # So that _edge_case doesn't end up assigning to scalars + x = x[:, None] + f = f[:, None] + hk = x[1:] - x[:-1] + df = jnp.diff(f, axis=axis) + hki = jnp.where(hk == 0, 0, 1 / hk) + if df.ndim > hki.ndim: + hki = jnp.expand_dims(hki, tuple(range(1, df.ndim))) + hki = jnp.moveaxis(hki, 0, axis) + + mk = hki * df + + smk = jnp.sign(mk) + condition = (smk[1:, :] != smk[:-1, :]) | (mk[1:, :] == 0) | (mk[:-1, :] == 0) + + w1 = 2 * hk[1:] + hk[:-1] + w2 = hk[1:] + 2 * hk[:-1] + + if df.ndim > w1.ndim: + w1 = jnp.expand_dims(w1, tuple(range(1, df.ndim))) + w1 = jnp.moveaxis(w1, 0, axis) + w2 = jnp.expand_dims(w2, tuple(range(1, df.ndim))) + w2 = jnp.moveaxis(w2, 0, axis) + + whmean = (w1 / mk[:-1, :] + w2 / mk[1:, :]) / (w1 + w2) + + dk = jnp.where(condition, 0, 1.0 / whmean) + + if zero_slope: + d0 = jnp.zeros((1, dk.shape[1])) + d1 = jnp.zeros((1, dk.shape[1])) + + else: + # special case endpoints, as suggested in + # Cleve Moler, Numerical Computing with MATLAB, Chap 3.6 (pchiptx.m) + def _edge_case(h0, h1, m0, m1): + # one-sided three-point estimate for the derivative + d = ((2 * h0 + h1) * m0 - h0 * m1) / (h0 + h1) + + # try to preserve shape + mask = jnp.sign(d) != jnp.sign(m0) + mask2 = (jnp.sign(m0) != jnp.sign(m1)) & (jnp.abs(d) > 3.0 * jnp.abs(m0)) + mmm = (~mask) & mask2 + + d = jnp.where(mask, 0.0, d) + d = jnp.where(mmm, 3.0 * m0, d) + return d + + hk = 1 / hki + d0 = _edge_case(hk[0, :], hk[1, :], mk[0, :], mk[1, :])[None] + d1 = _edge_case(hk[-1, :], hk[-2, :], mk[-1, :], mk[-2, :])[None] + + dk = jnp.concatenate([d0, dk, d1]) + dk = dk.reshape(fshp) + return dk.reshape(fshp) diff --git a/interpax/_spline.py b/interpax/_spline.py index e949323..6fe10f4 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -7,10 +7,11 @@ import equinox as eqx import jax import jax.numpy as jnp -import lineax as lx import numpy as np from jax import jit +from ._coefs import A_BICUBIC, A_CUBIC, A_TRICUBIC +from ._fd_derivs import approx_df from .utils import errorif, isbool CUBIC_METHODS = ("cubic", "cubic2", "cardinal", "catmull-rom") @@ -1161,349 +1162,3 @@ def noclip(fq, *_): ) return fq - - -@partial(jit, static_argnames=("method", "axis")) -def approx_df( - x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs -): - """Approximates first derivatives using cubic spline interpolation. - - Parameters - ---------- - x : ndarray, shape(Nx,) - coordinates of known function values ("knots") - f : ndarray - Known function values. Should have length ``Nx`` along axis=axis - method : str - method of approximation - - - ``'cubic'``: C1 cubic splines (aka local splines) - - ``'cubic2'``: C2 cubic splines (aka natural splines) - - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass - keyword parameter ``c`` in float[0,1] to specify tension - - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the - data, and will not introduce new extrema in the interpolated points - - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at - both endpoints - - axis : int - Axis along which f is varying. - - Returns - ------- - df : ndarray, shape(f.shape) - First derivative of f with respect to x. - - """ - if method == "cubic": - dx = jnp.diff(x) - df = jnp.diff(f, axis=axis) - dxi = jnp.where(dx == 0, 0, 1 / dx) - if df.ndim > dxi.ndim: - dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) - dxi = jnp.moveaxis(dxi, 0, axis) - df = dxi * df - fx = jnp.concatenate( - [ - jnp.take(df, jnp.array([0]), axis, mode="wrap"), - 1 - / 2 - * ( - jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - + jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - ), - jnp.take(df, jnp.array([-1]), axis, mode="wrap"), - ], - axis=axis, - ) - return fx - - elif method == "cubic2": - dx = jnp.diff(x) - df = jnp.diff(f, axis=axis) - if df.ndim > dx.ndim: - dx = jnp.expand_dims(dx, tuple(range(1, df.ndim))) - dx = jnp.moveaxis(dx, 0, axis) - dxi = jnp.where(dx == 0, 0, 1 / dx) - df = dxi * df - - one = jnp.array([1.0]) - dxflat = dx.flatten() - diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one]) - upper_diag = jnp.concatenate([one, dxflat[:-1]]) - lower_diag = jnp.concatenate([dxflat[1:], one]) - - A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) - - b = jnp.concatenate( - [ - 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), - 3 - * ( - jnp.take(dx, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - * jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - + jnp.take(dx, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - * jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - ), - 2 * jnp.take(df, jnp.array([-1]), axis, mode="wrap"), - ], - axis=axis, - ) - ba = jnp.moveaxis(b, axis, 0) - br = ba.reshape((b.shape[axis], -1)) - solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value - fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T - fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) - return fx - - elif method in ["cardinal", "catmull-rom"]: - dx = x[2:] - x[:-2] - df = jnp.take(f, jnp.arange(2, f.shape[axis]), axis, mode="wrap") - jnp.take( - f, jnp.arange(0, f.shape[axis] - 2), axis, mode="wrap" - ) - dxi = jnp.where(dx == 0, 0, 1 / dx) - if df.ndim > dxi.ndim: - dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) - dxi = jnp.moveaxis(dxi, 0, axis) - df = dxi * df - fx0 = jnp.take(f, jnp.array([1]), axis, mode="wrap") - jnp.take( - f, jnp.array([0]), axis, mode="wrap" - ) - fx0 *= jnp.where(x[0] == x[1], 0, 1 / (x[1] - x[0])) - fx1 = jnp.take(f, jnp.array([-1]), axis, mode="wrap") - jnp.take( - f, jnp.array([-2]), axis, mode="wrap" - ) - fx1 *= jnp.where(x[-1] == x[-2], 0, 1 / (x[-1] - x[-2])) - - if method == "cardinal": - c = kwargs.get("c", 0) - else: - c = 0 - fx = (1 - c) * jnp.concatenate([fx0, df, fx1], axis=axis) - return fx - - elif method in ["monotonic", "monotonic-0"]: - f = jnp.moveaxis(f, axis, 0) - fshp = f.shape - if f.ndim == 1: - # So that _edge_case doesn't end up assigning to scalars - x = x[:, None] - f = f[:, None] - hk = x[1:] - x[:-1] - df = jnp.diff(f, axis=axis) - hki = jnp.where(hk == 0, 0, 1 / hk) - if df.ndim > hki.ndim: - hki = jnp.expand_dims(hki, tuple(range(1, df.ndim))) - hki = jnp.moveaxis(hki, 0, axis) - - mk = hki * df - - smk = jnp.sign(mk) - condition = (smk[1:, :] != smk[:-1, :]) | (mk[1:, :] == 0) | (mk[:-1, :] == 0) - - w1 = 2 * hk[1:] + hk[:-1] - w2 = hk[1:] + 2 * hk[:-1] - - if df.ndim > w1.ndim: - w1 = jnp.expand_dims(w1, tuple(range(1, df.ndim))) - w1 = jnp.moveaxis(w1, 0, axis) - w2 = jnp.expand_dims(w2, tuple(range(1, df.ndim))) - w2 = jnp.moveaxis(w2, 0, axis) - - whmean = (w1 / mk[:-1, :] + w2 / mk[1:, :]) / (w1 + w2) - - dk = jnp.where(condition, 0, 1.0 / whmean) - - if method == "monotonic-0": - d0 = jnp.zeros((1, dk.shape[1])) - d1 = jnp.zeros((1, dk.shape[1])) - - else: - # special case endpoints, as suggested in - # Cleve Moler, Numerical Computing with MATLAB, Chap 3.6 (pchiptx.m) - def _edge_case(h0, h1, m0, m1): - # one-sided three-point estimate for the derivative - d = ((2 * h0 + h1) * m0 - h0 * m1) / (h0 + h1) - - # try to preserve shape - mask = jnp.sign(d) != jnp.sign(m0) - mask2 = (jnp.sign(m0) != jnp.sign(m1)) & ( - jnp.abs(d) > 3.0 * jnp.abs(m0) - ) - mmm = (~mask) & mask2 - - d = jnp.where(mask, 0.0, d) - d = jnp.where(mmm, 3.0 * m0, d) - return d - - hk = 1 / hki - d0 = _edge_case(hk[0, :], hk[1, :], mk[0, :], mk[1, :])[None] - d1 = _edge_case(hk[-1, :], hk[-2, :], mk[-1, :], mk[-2, :])[None] - - dk = jnp.concatenate([d0, dk, d1]) - dk = dk.reshape(fshp) - return dk.reshape(fshp) - - else: # method passed in does not use df from this function, just return 0 - return jnp.zeros_like(f) - - -# fmt: off -A_TRICUBIC = np.array([ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 9,-9,-9, 9, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 4, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 6,-6, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2,-2,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 6,-6, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2,-1,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 4,-4,-4, 4, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9,-9, 9, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 4, 2, 2, 1, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0,-2,-2,-1,-1, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0,-2,-1,-2,-1, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4,-4, 4, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], # noqa: E501 - [-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 9,-9, 0, 0,-9, 9, 0, 0, 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 0, 0, 6,-6, 0, 0,-3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9, 0, 0,-9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0], # noqa: E501 - [ 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0], # noqa: E501 - [-27,27,27,-27,27,-27,-27,27,-18,-9,18, 9,18, 9,-18,-9,-18,18,-9, 9,18,-18, 9,-9,-18,18,18,-18,-9, 9, 9, # noqa: E501 - -9,-12,-6,-6,-3,12, 6, 6, 3,-12,-6,12, 6,-6,-3, 6, 3,-12,12,-6, 6,-6, 6,-3, 3,-8,-4,-4,-2,-4,-2,-2,-1], # noqa: E501 - [18,-18,-18,18,-18,18,18,-18, 9, 9,-9,-9,-9,-9, 9, 9,12,-12, 6,-6,-12,12,-6, 6,12,-12,-12,12, 6,-6,-6, # noqa: E501 - 6, 6, 6, 3, 3,-6,-6,-3,-3, 6, 6,-6,-6, 3, 3,-3,-3, 8,-8, 4,-4, 4,-4, 2,-2, 4, 4, 2, 2, 2, 2, 1, 1], # noqa: E501 - [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0], # noqa: E501 - [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6, 9,-9, 9,-9,-9, 9,-9, 9,12,-12,-12,12, 6,-6,-6, # noqa: E501 - 6, 6, 3, 6, 3,-6,-3,-6,-3, 8, 4,-8,-4, 4, 2,-4,-2, 6,-6, 6,-6, 3,-3, 3,-3, 4, 2, 4, 2, 2, 1, 2, 1], # noqa: E501 - [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-6, 6,-6, 6, 6,-6, 6,-6,-8, 8, 8,-8,-4, 4, 4,-4, # noqa: E501 - -3,-3,-3,-3, 3, 3, 3, 3,-4,-4, 4, 4,-2,-2, 2, 2,-4, 4,-4, 4,-2, 2,-2, 2,-2,-2,-2,-2,-1,-1,-1,-1], # noqa: E501 - [ 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 0, 0, 6,-6, 0, 0,-4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 4,-4, 0, 0,-4, 4, 0, 0, 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4, 0, 0,-4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0], # noqa: E501 - [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0], # noqa: E501 - [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6,12,-12, 6,-6,-12,12,-6, 6, 9,-9,-9, 9, 9,-9,-9, # noqa: E501 - 9, 8, 4, 4, 2,-8,-4,-4,-2, 6, 3,-6,-3, 6, 3,-6,-3, 6,-6, 3,-3, 6,-6, 3,-3, 4, 2, 2, 1, 4, 2, 2, 1], # noqa: E501 - [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-8, 8,-4, 4, 8,-8, 4,-4,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 - -4,-4,-2,-2, 4, 4, 2, 2,-3,-3, 3, 3,-3,-3, 3, 3,-4, 4,-2, 2,-4, 4,-2, 2,-2,-2,-1,-1,-2,-2,-1,-1], # noqa: E501 - [ 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], # noqa: E501 - [-12,12,12,-12,12,-12,-12,12,-8,-4, 8, 4, 8, 4,-8,-4,-6, 6,-6, 6, 6,-6, 6,-6,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 - -4,-2,-4,-2, 4, 2, 4, 2,-4,-2, 4, 2,-4,-2, 4, 2,-3, 3,-3, 3,-3, 3,-3, 3,-2,-1,-2,-1,-2,-1,-2,-1], # noqa: E501 - [ 8,-8,-8, 8,-8, 8, 8,-8, 4, 4,-4,-4,-4,-4, 4, 4, 4,-4, 4,-4,-4, 4,-4, 4, 4,-4,-4, 4, 4,-4,-4, 4, # noqa: E501 - 2, 2, 2, 2,-2,-2,-2,-2, 2, 2,-2,-2, 2, 2,-2,-2, 2,-2, 2,-2, 2,-2, 2,-2, 1, 1, 1, 1, 1, 1, 1, 1] # noqa: E501 -]) - -A_BICUBIC = np.array([ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [-3, 3, 0, 0, -2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [2, -2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, -2, -1, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 1, 1, 0, 0 ], - [-3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, -3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0 ], - [9, -9, -9, 9, 6, 3, -6, -3, 6, -6, 3, -3, 4, 2, 2, 1 ], - [-6, 6, 6, -6, -3, -3, 3, 3, -4, 4, -2, 2, -2, -2, -1, -1 ], - [2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0 ], - [-6, 6, 6, -6, -4, -2, 4, 2, -3, 3, -3, 3, -2, -1, -2, -1 ], - [4, -4, -4, 4, 2, 2, -2, -2, 2, -2, 2, -2, 1, 1, 1, 1] -]) - -A_CUBIC = np.array([ - [1, 0, 0, 0], - [0, 0, 1, 0], - [-3, 3, -2, -1], - [2, -2, 1, 1], -]) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index 5f101ff..5e2ca53 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -434,41 +434,11 @@ def _finite_difference(self, f, x, eps=1e-8): @pytest.mark.unit def test_ad_interp1d(self): """Test AD of different 1d interpolation methods.""" - xp = np.linspace(0, 2 * np.pi, 100) - x = np.linspace(0, 2 * np.pi, 200) + xp = np.linspace(0, 2 * np.pi, 10) + x = np.linspace(0, 2 * np.pi, 20) f = lambda x: np.sin(x) fp = f(xp) - for method in ["cubic", "cubic2", "cardinal"]: - interp1 = lambda xq: interp1d(xq, xp, fp, method=method) - interp2 = lambda xq: Interpolator1D(xp, fp, method=method)(xq) - - f1 = jnp.vectorize(jax.grad(interp1))(x) - f2 = jnp.vectorize(jax.grad(interp2))(x) - - np.testing.assert_allclose(f1, np.cos(x), rtol=1e-2, atol=1e-2) - np.testing.assert_allclose(f1, f2) - - for method in ["cubic", "cubic2", "cardinal", "monotonic"]: - - interp1 = lambda fp: interp1d(x, xp, fp, method=method) - interp2 = lambda fp: Interpolator1D(xp, fp, method=method)(x) - - jacf1 = jax.jacfwd(interp1)(fp) - jacf2 = jax.jacfwd(interp2)(fp) - - jacr1 = jax.jacrev(interp1)(fp) - jacr2 = jax.jacrev(interp2)(fp) - - jacd1 = self._finite_difference(interp1, fp) - jacd2 = self._finite_difference(interp2, fp) - - np.testing.assert_allclose(jacf1, jacf2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacr1, jacr2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacf1, jacr1, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacf1, jacd1, rtol=1e-6, atol=1e-6) - np.testing.assert_allclose(jacf2, jacd2, rtol=1e-6, atol=1e-6) - for method in ["cubic", "cubic2", "cardinal", "monotonic"]: interp1 = lambda xp: interp1d(x, xp, fp, method=method) @@ -493,40 +463,15 @@ def test_ad_interp1d(self): @pytest.mark.unit def test_ad_interp2d(self): """Test AD of different 2d interpolation methods.""" - xp = np.linspace(0, 4 * np.pi, 40) - yp = np.linspace(0, 2 * np.pi, 40) - y = np.linspace(0, 2 * np.pi, 100) - x = np.linspace(0, 2 * np.pi, 100) + xp = np.linspace(0, 4 * np.pi, 20) + yp = np.linspace(0, 2 * np.pi, 20) + y = np.linspace(0, 2 * np.pi, 30) + x = np.linspace(0, 2 * np.pi, 30) xxp, yyp = np.meshgrid(xp, yp, indexing="ij") f = lambda x, y: np.sin(x) * np.cos(y) fp = f(xxp, yyp) - for method in ["cubic", "cubic2", "cardinal"]: - interp1 = lambda xq, yq: interp2d(xq, yq, xp, yp, fp, method=method) - interp2 = lambda xq, yq: Interpolator2D(xp, yp, fp, method=method)(xq, yq) - - f1 = jnp.vectorize(jax.grad(interp1))(x, y) - f2 = jnp.vectorize(jax.grad(interp2))(x, y) - - np.testing.assert_allclose(f1, np.cos(x) * np.cos(y), rtol=3e-2, atol=3e-2) - np.testing.assert_allclose(f1, f2) - - for method in ["cubic", "cubic2", "cardinal"]: - - interp1 = lambda fp: interp2d(x, y, xp, yp, fp, method=method) - interp2 = lambda fp: Interpolator2D(xp, yp, fp, method=method)(x, y) - - jacf1 = jax.jacfwd(interp1)(fp) - jacf2 = jax.jacfwd(interp2)(fp) - - jacr1 = jax.jacrev(interp1)(fp) - jacr2 = jax.jacrev(interp2)(fp) - - np.testing.assert_allclose(jacf1, jacf2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacr1, jacr2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacf1, jacr1, rtol=1e-14, atol=1e-14) - for method in ["cubic", "cubic2", "cardinal"]: interp1 = lambda xp: interp2d(x, y, xp, yp, fp, method=method) @@ -551,48 +496,17 @@ def test_ad_interp2d(self): @pytest.mark.unit def test_ad_interp3d(self): """Test AD of different 3d interpolation methods.""" - xp = np.linspace(0, np.pi, 20) - yp = np.linspace(0, 2 * np.pi, 30) - zp = np.linspace(0, 1, 10) - x = np.linspace(0, np.pi, 100) - y = np.linspace(0, 2 * np.pi, 100) - z = np.linspace(0, 1, 100) + xp = np.linspace(0, np.pi, 10) + yp = np.linspace(0, 2 * np.pi, 15) + zp = np.linspace(0, 1, 12) + x = np.linspace(0, np.pi, 13) + y = np.linspace(0, 2 * np.pi, 13) + z = np.linspace(0, 1, 13) xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij") f = lambda x, y, z: np.sin(x) * np.cos(y) * z**2 fp = f(xxp, yyp, zzp) - for method in ["cubic", "cubic2", "cardinal"]: - interp1 = lambda xq, yq, zq: interp3d( - xq, yq, zq, xp, yp, zp, fp, method=method - ) - interp2 = lambda xq, yq, zq: Interpolator3D(xp, yp, zp, fp, method=method)( - xq, yq, zq - ) - - f1 = jnp.vectorize(jax.grad(interp1))(x, y, z) - f2 = jnp.vectorize(jax.grad(interp2))(x, y, z) - - np.testing.assert_allclose( - f1, np.cos(x) * np.cos(y) * z**2, rtol=3e-2, atol=3e-2 - ) - np.testing.assert_allclose(f1, f2) - - for method in ["cubic", "cubic2", "cardinal"]: - - interp1 = lambda fp: interp3d(x, y, z, xp, yp, zp, fp, method=method) - interp2 = lambda fp: Interpolator3D(xp, yp, zp, fp, method=method)(x, y, z) - - jacf1 = jax.jacfwd(interp1)(fp) - jacf2 = jax.jacfwd(interp2)(fp) - - jacr1 = jax.jacrev(interp1)(fp) - jacr2 = jax.jacrev(interp2)(fp) - - np.testing.assert_allclose(jacf1, jacf2, rtol=1e-12, atol=1e-12) - np.testing.assert_allclose(jacr1, jacr2, rtol=1e-12, atol=1e-12) - np.testing.assert_allclose(jacf1, jacr1, rtol=1e-12, atol=1e-12) - for method in ["cubic", "cubic2", "cardinal"]: interp1 = lambda xp: interp3d(x, y, z, xp, yp, zp, fp, method=method) From 6fb64959650f8a458e8e94df5785a0ebe78e0c8d Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sun, 3 Mar 2024 19:14:48 -0500 Subject: [PATCH 3/4] Remove testing for python 3.8 --- .github/workflows/scheduled.yml | 2 +- .github/workflows/unittest.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/scheduled.yml b/.github/workflows/scheduled.yml index 9a0fd4f..284cb44 100644 --- a/.github/workflows/scheduled.yml +++ b/.github/workflows/scheduled.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index e8349d0..758d330 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v4 From c2d0e7cbc2eb19bae126f6ab0b535e206e1f39de Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sun, 3 Mar 2024 19:50:31 -0500 Subject: [PATCH 4/4] Separate lineax stuff for now --- interpax/_fd_derivs.py | 6 ++---- requirements.txt | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py index eb7e4fe..a8ba650 100644 --- a/interpax/_fd_derivs.py +++ b/interpax/_fd_derivs.py @@ -2,7 +2,6 @@ import jax import jax.numpy as jnp -import lineax as lx from jax import jit @@ -98,8 +97,7 @@ def _cubic2(x, f, axis): upper_diag = jnp.concatenate([one, dxflat[:-1]]) lower_diag = jnp.concatenate([dxflat[1:], one]) - A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) - + A = jnp.diag(diag) + jnp.diag(upper_diag, k=1) + jnp.diag(lower_diag, k=-1) b = jnp.concatenate( [ 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), @@ -116,7 +114,7 @@ def _cubic2(x, f, axis): ) ba = jnp.moveaxis(b, axis, 0) br = ba.reshape((b.shape[axis], -1)) - solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value + solve = lambda b: jnp.linalg.solve(A, b) fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) return fx diff --git a/requirements.txt b/requirements.txt index 9fa60ca..6272e0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ equinox jax >= 0.3.2, <= 0.5.0 -lineax numpy >= 1.20.0, < 2.0