From c2c273bf52bf715bdc8d2ece8cb2bb484e604f26 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sun, 3 Mar 2024 18:49:12 -0500 Subject: [PATCH] Move stuff around and reduce testing time --- interpax/__init__.py | 2 +- interpax/_coefs.py | 161 ++++++++++++++++++ interpax/_fd_derivs.py | 206 ++++++++++++++++++++++ interpax/_spline.py | 349 +------------------------------------- tests/test_interpolate.py | 110 ++---------- 5 files changed, 382 insertions(+), 446 deletions(-) create mode 100644 interpax/_coefs.py create mode 100644 interpax/_fd_derivs.py diff --git a/interpax/__init__.py b/interpax/__init__.py index a2dc73a..539bdb5 100644 --- a/interpax/__init__.py +++ b/interpax/__init__.py @@ -1,12 +1,12 @@ """interpax: interpolation and function approximation with JAX.""" from . import _version +from ._fd_derivs import approx_df from ._fourier import fft_interp1d, fft_interp2d from ._spline import ( Interpolator1D, Interpolator2D, Interpolator3D, - approx_df, interp1d, interp2d, interp3d, diff --git a/interpax/_coefs.py b/interpax/_coefs.py new file mode 100644 index 0000000..2ba6938 --- /dev/null +++ b/interpax/_coefs.py @@ -0,0 +1,161 @@ +"""Matrices/tensors for getting spline coefficients from derivatives.""" + +import numpy as np + +# fmt: off +A_TRICUBIC = np.array([ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 9,-9,-9, 9, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 4, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 6,-6, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2,-2,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 6,-6, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -2,-1,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 4,-4,-4, 4, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9,-9, 9, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 4, 2, 2, 1, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0,-2,-2,-1,-1, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0,-2,-1,-2,-1, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4,-4, 4, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], # noqa: E501 + [-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 9,-9, 0, 0,-9, 9, 0, 0, 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 0, 0, 6,-6, 0, 0,-3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9, 0, 0,-9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0], # noqa: E501 + [ 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0], # noqa: E501 + [-27,27,27,-27,27,-27,-27,27,-18,-9,18, 9,18, 9,-18,-9,-18,18,-9, 9,18,-18, 9,-9,-18,18,18,-18,-9, 9, 9, # noqa: E501 + -9,-12,-6,-6,-3,12, 6, 6, 3,-12,-6,12, 6,-6,-3, 6, 3,-12,12,-6, 6,-6, 6,-3, 3,-8,-4,-4,-2,-4,-2,-2,-1], # noqa: E501 + [18,-18,-18,18,-18,18,18,-18, 9, 9,-9,-9,-9,-9, 9, 9,12,-12, 6,-6,-12,12,-6, 6,12,-12,-12,12, 6,-6,-6, # noqa: E501 + 6, 6, 6, 3, 3,-6,-6,-3,-3, 6, 6,-6,-6, 3, 3,-3,-3, 8,-8, 4,-4, 4,-4, 2,-2, 4, 4, 2, 2, 2, 2, 1, 1], # noqa: E501 + [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0], # noqa: E501 + [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6, 9,-9, 9,-9,-9, 9,-9, 9,12,-12,-12,12, 6,-6,-6, # noqa: E501 + 6, 6, 3, 6, 3,-6,-3,-6,-3, 8, 4,-8,-4, 4, 2,-4,-2, 6,-6, 6,-6, 3,-3, 3,-3, 4, 2, 4, 2, 2, 1, 2, 1], # noqa: E501 + [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-6, 6,-6, 6, 6,-6, 6,-6,-8, 8, 8,-8,-4, 4, 4,-4, # noqa: E501 + -3,-3,-3,-3, 3, 3, 3, 3,-4,-4, 4, 4,-2,-2, 2, 2,-4, 4,-4, 4,-2, 2,-2, 2,-2,-2,-2,-2,-1,-1,-1,-1], # noqa: E501 + [ 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [-6, 6, 0, 0, 6,-6, 0, 0,-4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 4,-4, 0, 0,-4, 4, 0, 0, 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4, 0, 0,-4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0], # noqa: E501 + [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + -4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0], # noqa: E501 + [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6,12,-12, 6,-6,-12,12,-6, 6, 9,-9,-9, 9, 9,-9,-9, # noqa: E501 + 9, 8, 4, 4, 2,-8,-4,-4,-2, 6, 3,-6,-3, 6, 3,-6,-3, 6,-6, 3,-3, 6,-6, 3,-3, 4, 2, 2, 1, 4, 2, 2, 1], # noqa: E501 + [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-8, 8,-4, 4, 8,-8, 4,-4,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 + -4,-4,-2,-2, 4, 4, 2, 2,-3,-3, 3, 3,-3,-3, 3, 3,-4, 4,-2, 2,-4, 4,-2, 2,-2,-2,-1,-1,-2,-2,-1,-1], # noqa: E501 + [ 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, # noqa: E501 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 + [ 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 + 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], # noqa: E501 + [-12,12,12,-12,12,-12,-12,12,-8,-4, 8, 4, 8, 4,-8,-4,-6, 6,-6, 6, 6,-6, 6,-6,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 + -4,-2,-4,-2, 4, 2, 4, 2,-4,-2, 4, 2,-4,-2, 4, 2,-3, 3,-3, 3,-3, 3,-3, 3,-2,-1,-2,-1,-2,-1,-2,-1], # noqa: E501 + [ 8,-8,-8, 8,-8, 8, 8,-8, 4, 4,-4,-4,-4,-4, 4, 4, 4,-4, 4,-4,-4, 4,-4, 4, 4,-4,-4, 4, 4,-4,-4, 4, # noqa: E501 + 2, 2, 2, 2,-2,-2,-2,-2, 2, 2,-2,-2, 2, 2,-2,-2, 2,-2, 2,-2, 2,-2, 2,-2, 1, 1, 1, 1, 1, 1, 1, 1] # noqa: E501 +]) + +A_BICUBIC = np.array([ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [-3, 3, 0, 0, -2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [2, -2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, -2, -1, 0, 0 ], + [0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 1, 1, 0, 0 ], + [-3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, -3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0 ], + [9, -9, -9, 9, 6, 3, -6, -3, 6, -6, 3, -3, 4, 2, 2, 1 ], + [-6, 6, 6, -6, -3, -3, 3, 3, -4, 4, -2, 2, -2, -2, -1, -1 ], + [2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0 ], + [0, 0, 0, 0, 2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0 ], + [-6, 6, 6, -6, -4, -2, 4, 2, -3, 3, -3, 3, -2, -1, -2, -1 ], + [4, -4, -4, 4, 2, 2, -2, -2, 2, -2, 2, -2, 1, 1, 1, 1] +]) + +A_CUBIC = np.array([ + [1, 0, 0, 0], + [0, 0, 1, 0], + [-3, 3, -2, -1], + [2, -2, 1, 1], +]) diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py new file mode 100644 index 0000000..eb7e4fe --- /dev/null +++ b/interpax/_fd_derivs.py @@ -0,0 +1,206 @@ +from functools import partial + +import jax +import jax.numpy as jnp +import lineax as lx +from jax import jit + + +@partial(jit, static_argnames=("method", "axis")) +def approx_df( + x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs +): + """Approximates first derivatives using cubic spline interpolation. + + Parameters + ---------- + x : ndarray, shape(Nx,) + coordinates of known function values ("knots") + f : ndarray + Known function values. Should have length ``Nx`` along axis=axis + method : str + method of approximation + + - ``'cubic'``: C1 cubic splines (aka local splines) + - ``'cubic2'``: C2 cubic splines (aka natural splines) + - ``'catmull-rom'``: C1 cubic centripetal "tension" splines + - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass + keyword parameter ``c`` in float[0,1] to specify tension + - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the + data, and will not introduce new extrema in the interpolated points + - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at + both endpoints + + axis : int + Axis along which f is varying. + + Returns + ------- + df : ndarray, shape(f.shape) + First derivative of f with respect to x. + + """ + if method == "cubic": + out = _cubic1(x, f, axis, **kwargs) + elif method == "cubic2": + out = _cubic2(x, f, axis) + elif method == "cardinal": + out = _cardinal(x, f, axis, **kwargs) + elif method == "catmull-rom": + out = _cardinal(x, f, axis, **kwargs) + elif method == "monotonic": + out = _monotonic(x, f, axis, False, **kwargs) + elif method == "monotonic-0": + out = _monotonic(x, f, axis, True, **kwargs) + elif method in ("nearest", "linear"): + out = jnp.zeros_like(f) + else: + raise ValueError(f"got unknown method {method}") + return out + + +def _cubic1(x, f, axis): + dx = jnp.diff(x) + df = jnp.diff(f, axis=axis) + dxi = jnp.where(dx == 0, 0, 1 / dx) + if df.ndim > dxi.ndim: + dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) + dxi = jnp.moveaxis(dxi, 0, axis) + df = dxi * df + fx = jnp.concatenate( + [ + jnp.take(df, jnp.array([0]), axis, mode="wrap"), + 1 + / 2 + * ( + jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") + + jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") + ), + jnp.take(df, jnp.array([-1]), axis, mode="wrap"), + ], + axis=axis, + ) + return fx + + +def _cubic2(x, f, axis): + dx = jnp.diff(x) + df = jnp.diff(f, axis=axis) + if df.ndim > dx.ndim: + dx = jnp.expand_dims(dx, tuple(range(1, df.ndim))) + dx = jnp.moveaxis(dx, 0, axis) + dxi = jnp.where(dx == 0, 0, 1 / dx) + df = dxi * df + + one = jnp.array([1.0]) + dxflat = dx.flatten() + diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one]) + upper_diag = jnp.concatenate([one, dxflat[:-1]]) + lower_diag = jnp.concatenate([dxflat[1:], one]) + + A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) + + b = jnp.concatenate( + [ + 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), + 3 + * ( + jnp.take(dx, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") + * jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") + + jnp.take(dx, jnp.arange(1, df.shape[axis]), axis, mode="wrap") + * jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") + ), + 2 * jnp.take(df, jnp.array([-1]), axis, mode="wrap"), + ], + axis=axis, + ) + ba = jnp.moveaxis(b, axis, 0) + br = ba.reshape((b.shape[axis], -1)) + solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value + fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T + fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) + return fx + + +def _cardinal(x, f, axis, c=0): + dx = x[2:] - x[:-2] + df = jnp.take(f, jnp.arange(2, f.shape[axis]), axis, mode="wrap") - jnp.take( + f, jnp.arange(0, f.shape[axis] - 2), axis, mode="wrap" + ) + dxi = jnp.where(dx == 0, 0, 1 / dx) + if df.ndim > dxi.ndim: + dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) + dxi = jnp.moveaxis(dxi, 0, axis) + df = dxi * df + fx0 = jnp.take(f, jnp.array([1]), axis, mode="wrap") - jnp.take( + f, jnp.array([0]), axis, mode="wrap" + ) + fx0 *= jnp.where(x[0] == x[1], 0, 1 / (x[1] - x[0])) + fx1 = jnp.take(f, jnp.array([-1]), axis, mode="wrap") - jnp.take( + f, jnp.array([-2]), axis, mode="wrap" + ) + fx1 *= jnp.where(x[-1] == x[-2], 0, 1 / (x[-1] - x[-2])) + + fx = (1 - c) * jnp.concatenate([fx0, df, fx1], axis=axis) + return fx + + +def _monotonic(x, f, axis, zero_slope): + f = jnp.moveaxis(f, axis, 0) + fshp = f.shape + if f.ndim == 1: + # So that _edge_case doesn't end up assigning to scalars + x = x[:, None] + f = f[:, None] + hk = x[1:] - x[:-1] + df = jnp.diff(f, axis=axis) + hki = jnp.where(hk == 0, 0, 1 / hk) + if df.ndim > hki.ndim: + hki = jnp.expand_dims(hki, tuple(range(1, df.ndim))) + hki = jnp.moveaxis(hki, 0, axis) + + mk = hki * df + + smk = jnp.sign(mk) + condition = (smk[1:, :] != smk[:-1, :]) | (mk[1:, :] == 0) | (mk[:-1, :] == 0) + + w1 = 2 * hk[1:] + hk[:-1] + w2 = hk[1:] + 2 * hk[:-1] + + if df.ndim > w1.ndim: + w1 = jnp.expand_dims(w1, tuple(range(1, df.ndim))) + w1 = jnp.moveaxis(w1, 0, axis) + w2 = jnp.expand_dims(w2, tuple(range(1, df.ndim))) + w2 = jnp.moveaxis(w2, 0, axis) + + whmean = (w1 / mk[:-1, :] + w2 / mk[1:, :]) / (w1 + w2) + + dk = jnp.where(condition, 0, 1.0 / whmean) + + if zero_slope: + d0 = jnp.zeros((1, dk.shape[1])) + d1 = jnp.zeros((1, dk.shape[1])) + + else: + # special case endpoints, as suggested in + # Cleve Moler, Numerical Computing with MATLAB, Chap 3.6 (pchiptx.m) + def _edge_case(h0, h1, m0, m1): + # one-sided three-point estimate for the derivative + d = ((2 * h0 + h1) * m0 - h0 * m1) / (h0 + h1) + + # try to preserve shape + mask = jnp.sign(d) != jnp.sign(m0) + mask2 = (jnp.sign(m0) != jnp.sign(m1)) & (jnp.abs(d) > 3.0 * jnp.abs(m0)) + mmm = (~mask) & mask2 + + d = jnp.where(mask, 0.0, d) + d = jnp.where(mmm, 3.0 * m0, d) + return d + + hk = 1 / hki + d0 = _edge_case(hk[0, :], hk[1, :], mk[0, :], mk[1, :])[None] + d1 = _edge_case(hk[-1, :], hk[-2, :], mk[-1, :], mk[-2, :])[None] + + dk = jnp.concatenate([d0, dk, d1]) + dk = dk.reshape(fshp) + return dk.reshape(fshp) diff --git a/interpax/_spline.py b/interpax/_spline.py index e949323..6fe10f4 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -7,10 +7,11 @@ import equinox as eqx import jax import jax.numpy as jnp -import lineax as lx import numpy as np from jax import jit +from ._coefs import A_BICUBIC, A_CUBIC, A_TRICUBIC +from ._fd_derivs import approx_df from .utils import errorif, isbool CUBIC_METHODS = ("cubic", "cubic2", "cardinal", "catmull-rom") @@ -1161,349 +1162,3 @@ def noclip(fq, *_): ) return fq - - -@partial(jit, static_argnames=("method", "axis")) -def approx_df( - x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs -): - """Approximates first derivatives using cubic spline interpolation. - - Parameters - ---------- - x : ndarray, shape(Nx,) - coordinates of known function values ("knots") - f : ndarray - Known function values. Should have length ``Nx`` along axis=axis - method : str - method of approximation - - - ``'cubic'``: C1 cubic splines (aka local splines) - - ``'cubic2'``: C2 cubic splines (aka natural splines) - - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass - keyword parameter ``c`` in float[0,1] to specify tension - - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the - data, and will not introduce new extrema in the interpolated points - - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at - both endpoints - - axis : int - Axis along which f is varying. - - Returns - ------- - df : ndarray, shape(f.shape) - First derivative of f with respect to x. - - """ - if method == "cubic": - dx = jnp.diff(x) - df = jnp.diff(f, axis=axis) - dxi = jnp.where(dx == 0, 0, 1 / dx) - if df.ndim > dxi.ndim: - dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) - dxi = jnp.moveaxis(dxi, 0, axis) - df = dxi * df - fx = jnp.concatenate( - [ - jnp.take(df, jnp.array([0]), axis, mode="wrap"), - 1 - / 2 - * ( - jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - + jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - ), - jnp.take(df, jnp.array([-1]), axis, mode="wrap"), - ], - axis=axis, - ) - return fx - - elif method == "cubic2": - dx = jnp.diff(x) - df = jnp.diff(f, axis=axis) - if df.ndim > dx.ndim: - dx = jnp.expand_dims(dx, tuple(range(1, df.ndim))) - dx = jnp.moveaxis(dx, 0, axis) - dxi = jnp.where(dx == 0, 0, 1 / dx) - df = dxi * df - - one = jnp.array([1.0]) - dxflat = dx.flatten() - diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one]) - upper_diag = jnp.concatenate([one, dxflat[:-1]]) - lower_diag = jnp.concatenate([dxflat[1:], one]) - - A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag) - - b = jnp.concatenate( - [ - 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), - 3 - * ( - jnp.take(dx, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - * jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - + jnp.take(dx, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - * jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - ), - 2 * jnp.take(df, jnp.array([-1]), axis, mode="wrap"), - ], - axis=axis, - ) - ba = jnp.moveaxis(b, axis, 0) - br = ba.reshape((b.shape[axis], -1)) - solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value - fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T - fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) - return fx - - elif method in ["cardinal", "catmull-rom"]: - dx = x[2:] - x[:-2] - df = jnp.take(f, jnp.arange(2, f.shape[axis]), axis, mode="wrap") - jnp.take( - f, jnp.arange(0, f.shape[axis] - 2), axis, mode="wrap" - ) - dxi = jnp.where(dx == 0, 0, 1 / dx) - if df.ndim > dxi.ndim: - dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim))) - dxi = jnp.moveaxis(dxi, 0, axis) - df = dxi * df - fx0 = jnp.take(f, jnp.array([1]), axis, mode="wrap") - jnp.take( - f, jnp.array([0]), axis, mode="wrap" - ) - fx0 *= jnp.where(x[0] == x[1], 0, 1 / (x[1] - x[0])) - fx1 = jnp.take(f, jnp.array([-1]), axis, mode="wrap") - jnp.take( - f, jnp.array([-2]), axis, mode="wrap" - ) - fx1 *= jnp.where(x[-1] == x[-2], 0, 1 / (x[-1] - x[-2])) - - if method == "cardinal": - c = kwargs.get("c", 0) - else: - c = 0 - fx = (1 - c) * jnp.concatenate([fx0, df, fx1], axis=axis) - return fx - - elif method in ["monotonic", "monotonic-0"]: - f = jnp.moveaxis(f, axis, 0) - fshp = f.shape - if f.ndim == 1: - # So that _edge_case doesn't end up assigning to scalars - x = x[:, None] - f = f[:, None] - hk = x[1:] - x[:-1] - df = jnp.diff(f, axis=axis) - hki = jnp.where(hk == 0, 0, 1 / hk) - if df.ndim > hki.ndim: - hki = jnp.expand_dims(hki, tuple(range(1, df.ndim))) - hki = jnp.moveaxis(hki, 0, axis) - - mk = hki * df - - smk = jnp.sign(mk) - condition = (smk[1:, :] != smk[:-1, :]) | (mk[1:, :] == 0) | (mk[:-1, :] == 0) - - w1 = 2 * hk[1:] + hk[:-1] - w2 = hk[1:] + 2 * hk[:-1] - - if df.ndim > w1.ndim: - w1 = jnp.expand_dims(w1, tuple(range(1, df.ndim))) - w1 = jnp.moveaxis(w1, 0, axis) - w2 = jnp.expand_dims(w2, tuple(range(1, df.ndim))) - w2 = jnp.moveaxis(w2, 0, axis) - - whmean = (w1 / mk[:-1, :] + w2 / mk[1:, :]) / (w1 + w2) - - dk = jnp.where(condition, 0, 1.0 / whmean) - - if method == "monotonic-0": - d0 = jnp.zeros((1, dk.shape[1])) - d1 = jnp.zeros((1, dk.shape[1])) - - else: - # special case endpoints, as suggested in - # Cleve Moler, Numerical Computing with MATLAB, Chap 3.6 (pchiptx.m) - def _edge_case(h0, h1, m0, m1): - # one-sided three-point estimate for the derivative - d = ((2 * h0 + h1) * m0 - h0 * m1) / (h0 + h1) - - # try to preserve shape - mask = jnp.sign(d) != jnp.sign(m0) - mask2 = (jnp.sign(m0) != jnp.sign(m1)) & ( - jnp.abs(d) > 3.0 * jnp.abs(m0) - ) - mmm = (~mask) & mask2 - - d = jnp.where(mask, 0.0, d) - d = jnp.where(mmm, 3.0 * m0, d) - return d - - hk = 1 / hki - d0 = _edge_case(hk[0, :], hk[1, :], mk[0, :], mk[1, :])[None] - d1 = _edge_case(hk[-1, :], hk[-2, :], mk[-1, :], mk[-2, :])[None] - - dk = jnp.concatenate([d0, dk, d1]) - dk = dk.reshape(fshp) - return dk.reshape(fshp) - - else: # method passed in does not use df from this function, just return 0 - return jnp.zeros_like(f) - - -# fmt: off -A_TRICUBIC = np.array([ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 9,-9,-9, 9, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 4, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 6,-6, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2,-2,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 6,-6, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -2,-1,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 4,-4,-4, 4, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9,-9, 9, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 6, 3,-6,-3, 0, 0, 0, 0, 6,-6, 3,-3, 0, 0, 0, 0, 4, 2, 2, 1, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-3,-3, 3, 3, 0, 0, 0, 0,-4, 4,-2, 2, 0, 0, 0, 0,-2,-2,-1,-1, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 6,-6, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-4,-2, 4, 2, 0, 0, 0, 0,-3, 3,-3, 3, 0, 0, 0, 0,-2,-1,-2,-1, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4,-4, 4, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 2, 2,-2,-2, 0, 0, 0, 0, 2,-2, 2,-2, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], # noqa: E501 - [-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 9,-9, 0, 0,-9, 9, 0, 0, 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 0, 0, 6,-6, 0, 0,-3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0, 0, 0,-1, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9,-9, 0, 0,-9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 6, 3, 0, 0,-6,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,-6, 0, 0, 3,-3, 0, 0, 4, 2, 0, 0, 2, 1, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -3,-3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 4, 0, 0,-2, 2, 0, 0,-2,-2, 0, 0,-1,-1, 0, 0], # noqa: E501 - [ 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 9, 0,-9, 0,-9, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 6, 0, 3, 0,-6, 0,-3, 0, 6, 0,-6, 0, 3, 0,-3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 1, 0], # noqa: E501 - [-27,27,27,-27,27,-27,-27,27,-18,-9,18, 9,18, 9,-18,-9,-18,18,-9, 9,18,-18, 9,-9,-18,18,18,-18,-9, 9, 9, # noqa: E501 - -9,-12,-6,-6,-3,12, 6, 6, 3,-12,-6,12, 6,-6,-3, 6, 3,-12,12,-6, 6,-6, 6,-3, 3,-8,-4,-4,-2,-4,-2,-2,-1], # noqa: E501 - [18,-18,-18,18,-18,18,18,-18, 9, 9,-9,-9,-9,-9, 9, 9,12,-12, 6,-6,-12,12,-6, 6,12,-12,-12,12, 6,-6,-6, # noqa: E501 - 6, 6, 6, 3, 3,-6,-6,-3,-3, 6, 6,-6,-6, 3, 3,-3,-3, 8,-8, 4,-4, 4,-4, 2,-2, 4, 4, 2, 2, 2, 2, 1, 1], # noqa: E501 - [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -3, 0,-3, 0, 3, 0, 3, 0,-4, 0, 4, 0,-2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-2, 0,-1, 0,-1, 0], # noqa: E501 - [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6, 9,-9, 9,-9,-9, 9,-9, 9,12,-12,-12,12, 6,-6,-6, # noqa: E501 - 6, 6, 3, 6, 3,-6,-3,-6,-3, 8, 4,-8,-4, 4, 2,-4,-2, 6,-6, 6,-6, 3,-3, 3,-3, 4, 2, 4, 2, 2, 1, 2, 1], # noqa: E501 - [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-6, 6,-6, 6, 6,-6, 6,-6,-8, 8, 8,-8,-4, 4, 4,-4, # noqa: E501 - -3,-3,-3,-3, 3, 3, 3, 3,-4,-4, 4, 4,-2,-2, 2, 2,-4, 4,-4, 4,-2, 2,-2, 2,-2,-2,-2,-2,-1,-1,-1,-1], # noqa: E501 - [ 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [-6, 6, 0, 0, 6,-6, 0, 0,-4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 4,-4, 0, 0,-4, 4, 0, 0, 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 2, 0, 0, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-6, 6, 0, 0, 6,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -4,-2, 0, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-3, 3, 0, 0,-3, 3, 0, 0,-2,-1, 0, 0,-2,-1, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,-4, 0, 0,-4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 2, 2, 0, 0,-2,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,-2, 0, 0, 2,-2, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0], # noqa: E501 - [-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0,-4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0,-6, 0, 6, 0, 6, 0,-6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - -4, 0,-2, 0, 4, 0, 2, 0,-3, 0, 3, 0,-3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,-2, 0,-1, 0,-2, 0,-1, 0], # noqa: E501 - [18,-18,-18,18,-18,18,18,-18,12, 6,-12,-6,-12,-6,12, 6,12,-12, 6,-6,-12,12,-6, 6, 9,-9,-9, 9, 9,-9,-9, # noqa: E501 - 9, 8, 4, 4, 2,-8,-4,-4,-2, 6, 3,-6,-3, 6, 3,-6,-3, 6,-6, 3,-3, 6,-6, 3,-3, 4, 2, 2, 1, 4, 2, 2, 1], # noqa: E501 - [-12,12,12,-12,12,-12,-12,12,-6,-6, 6, 6, 6, 6,-6,-6,-8, 8,-4, 4, 8,-8, 4,-4,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 - -4,-4,-2,-2, 4, 4, 2, 2,-3,-3, 3, 3,-3,-3, 3, 3,-4, 4,-2, 2,-4, 4,-2, 2,-2,-2,-1,-1,-2,-2,-1,-1], # noqa: E501 - [ 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, # noqa: E501 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # noqa: E501 - [ 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,-4, 0,-4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # noqa: E501 - 2, 0, 2, 0,-2, 0,-2, 0, 2, 0,-2, 0, 2, 0,-2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], # noqa: E501 - [-12,12,12,-12,12,-12,-12,12,-8,-4, 8, 4, 8, 4,-8,-4,-6, 6,-6, 6, 6,-6, 6,-6,-6, 6, 6,-6,-6, 6, 6,-6, # noqa: E501 - -4,-2,-4,-2, 4, 2, 4, 2,-4,-2, 4, 2,-4,-2, 4, 2,-3, 3,-3, 3,-3, 3,-3, 3,-2,-1,-2,-1,-2,-1,-2,-1], # noqa: E501 - [ 8,-8,-8, 8,-8, 8, 8,-8, 4, 4,-4,-4,-4,-4, 4, 4, 4,-4, 4,-4,-4, 4,-4, 4, 4,-4,-4, 4, 4,-4,-4, 4, # noqa: E501 - 2, 2, 2, 2,-2,-2,-2,-2, 2, 2,-2,-2, 2, 2,-2,-2, 2,-2, 2,-2, 2,-2, 2,-2, 1, 1, 1, 1, 1, 1, 1, 1] # noqa: E501 -]) - -A_BICUBIC = np.array([ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [-3, 3, 0, 0, -2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [2, -2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, -2, -1, 0, 0 ], - [0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 1, 1, 0, 0 ], - [-3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, -3, 0, 3, 0, 0, 0, 0, 0, -2, 0, -1, 0 ], - [9, -9, -9, 9, 6, 3, -6, -3, 6, -6, 3, -3, 4, 2, 2, 1 ], - [-6, 6, 6, -6, -3, -3, 3, 3, -4, 4, -2, 2, -2, -2, -1, -1 ], - [2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0 ], - [0, 0, 0, 0, 2, 0, -2, 0, 0, 0, 0, 0, 1, 0, 1, 0 ], - [-6, 6, 6, -6, -4, -2, 4, 2, -3, 3, -3, 3, -2, -1, -2, -1 ], - [4, -4, -4, 4, 2, 2, -2, -2, 2, -2, 2, -2, 1, 1, 1, 1] -]) - -A_CUBIC = np.array([ - [1, 0, 0, 0], - [0, 0, 1, 0], - [-3, 3, -2, -1], - [2, -2, 1, 1], -]) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index 5f101ff..5e2ca53 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -434,41 +434,11 @@ def _finite_difference(self, f, x, eps=1e-8): @pytest.mark.unit def test_ad_interp1d(self): """Test AD of different 1d interpolation methods.""" - xp = np.linspace(0, 2 * np.pi, 100) - x = np.linspace(0, 2 * np.pi, 200) + xp = np.linspace(0, 2 * np.pi, 10) + x = np.linspace(0, 2 * np.pi, 20) f = lambda x: np.sin(x) fp = f(xp) - for method in ["cubic", "cubic2", "cardinal"]: - interp1 = lambda xq: interp1d(xq, xp, fp, method=method) - interp2 = lambda xq: Interpolator1D(xp, fp, method=method)(xq) - - f1 = jnp.vectorize(jax.grad(interp1))(x) - f2 = jnp.vectorize(jax.grad(interp2))(x) - - np.testing.assert_allclose(f1, np.cos(x), rtol=1e-2, atol=1e-2) - np.testing.assert_allclose(f1, f2) - - for method in ["cubic", "cubic2", "cardinal", "monotonic"]: - - interp1 = lambda fp: interp1d(x, xp, fp, method=method) - interp2 = lambda fp: Interpolator1D(xp, fp, method=method)(x) - - jacf1 = jax.jacfwd(interp1)(fp) - jacf2 = jax.jacfwd(interp2)(fp) - - jacr1 = jax.jacrev(interp1)(fp) - jacr2 = jax.jacrev(interp2)(fp) - - jacd1 = self._finite_difference(interp1, fp) - jacd2 = self._finite_difference(interp2, fp) - - np.testing.assert_allclose(jacf1, jacf2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacr1, jacr2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacf1, jacr1, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacf1, jacd1, rtol=1e-6, atol=1e-6) - np.testing.assert_allclose(jacf2, jacd2, rtol=1e-6, atol=1e-6) - for method in ["cubic", "cubic2", "cardinal", "monotonic"]: interp1 = lambda xp: interp1d(x, xp, fp, method=method) @@ -493,40 +463,15 @@ def test_ad_interp1d(self): @pytest.mark.unit def test_ad_interp2d(self): """Test AD of different 2d interpolation methods.""" - xp = np.linspace(0, 4 * np.pi, 40) - yp = np.linspace(0, 2 * np.pi, 40) - y = np.linspace(0, 2 * np.pi, 100) - x = np.linspace(0, 2 * np.pi, 100) + xp = np.linspace(0, 4 * np.pi, 20) + yp = np.linspace(0, 2 * np.pi, 20) + y = np.linspace(0, 2 * np.pi, 30) + x = np.linspace(0, 2 * np.pi, 30) xxp, yyp = np.meshgrid(xp, yp, indexing="ij") f = lambda x, y: np.sin(x) * np.cos(y) fp = f(xxp, yyp) - for method in ["cubic", "cubic2", "cardinal"]: - interp1 = lambda xq, yq: interp2d(xq, yq, xp, yp, fp, method=method) - interp2 = lambda xq, yq: Interpolator2D(xp, yp, fp, method=method)(xq, yq) - - f1 = jnp.vectorize(jax.grad(interp1))(x, y) - f2 = jnp.vectorize(jax.grad(interp2))(x, y) - - np.testing.assert_allclose(f1, np.cos(x) * np.cos(y), rtol=3e-2, atol=3e-2) - np.testing.assert_allclose(f1, f2) - - for method in ["cubic", "cubic2", "cardinal"]: - - interp1 = lambda fp: interp2d(x, y, xp, yp, fp, method=method) - interp2 = lambda fp: Interpolator2D(xp, yp, fp, method=method)(x, y) - - jacf1 = jax.jacfwd(interp1)(fp) - jacf2 = jax.jacfwd(interp2)(fp) - - jacr1 = jax.jacrev(interp1)(fp) - jacr2 = jax.jacrev(interp2)(fp) - - np.testing.assert_allclose(jacf1, jacf2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacr1, jacr2, rtol=1e-14, atol=1e-14) - np.testing.assert_allclose(jacf1, jacr1, rtol=1e-14, atol=1e-14) - for method in ["cubic", "cubic2", "cardinal"]: interp1 = lambda xp: interp2d(x, y, xp, yp, fp, method=method) @@ -551,48 +496,17 @@ def test_ad_interp2d(self): @pytest.mark.unit def test_ad_interp3d(self): """Test AD of different 3d interpolation methods.""" - xp = np.linspace(0, np.pi, 20) - yp = np.linspace(0, 2 * np.pi, 30) - zp = np.linspace(0, 1, 10) - x = np.linspace(0, np.pi, 100) - y = np.linspace(0, 2 * np.pi, 100) - z = np.linspace(0, 1, 100) + xp = np.linspace(0, np.pi, 10) + yp = np.linspace(0, 2 * np.pi, 15) + zp = np.linspace(0, 1, 12) + x = np.linspace(0, np.pi, 13) + y = np.linspace(0, 2 * np.pi, 13) + z = np.linspace(0, 1, 13) xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij") f = lambda x, y, z: np.sin(x) * np.cos(y) * z**2 fp = f(xxp, yyp, zzp) - for method in ["cubic", "cubic2", "cardinal"]: - interp1 = lambda xq, yq, zq: interp3d( - xq, yq, zq, xp, yp, zp, fp, method=method - ) - interp2 = lambda xq, yq, zq: Interpolator3D(xp, yp, zp, fp, method=method)( - xq, yq, zq - ) - - f1 = jnp.vectorize(jax.grad(interp1))(x, y, z) - f2 = jnp.vectorize(jax.grad(interp2))(x, y, z) - - np.testing.assert_allclose( - f1, np.cos(x) * np.cos(y) * z**2, rtol=3e-2, atol=3e-2 - ) - np.testing.assert_allclose(f1, f2) - - for method in ["cubic", "cubic2", "cardinal"]: - - interp1 = lambda fp: interp3d(x, y, z, xp, yp, zp, fp, method=method) - interp2 = lambda fp: Interpolator3D(xp, yp, zp, fp, method=method)(x, y, z) - - jacf1 = jax.jacfwd(interp1)(fp) - jacf2 = jax.jacfwd(interp2)(fp) - - jacr1 = jax.jacrev(interp1)(fp) - jacr2 = jax.jacrev(interp2)(fp) - - np.testing.assert_allclose(jacf1, jacf2, rtol=1e-12, atol=1e-12) - np.testing.assert_allclose(jacr1, jacr2, rtol=1e-12, atol=1e-12) - np.testing.assert_allclose(jacf1, jacr1, rtol=1e-12, atol=1e-12) - for method in ["cubic", "cubic2", "cardinal"]: interp1 = lambda xp: interp3d(x, y, z, xp, yp, zp, fp, method=method)