diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py index a8ba650..895f506 100644 --- a/interpax/_fd_derivs.py +++ b/interpax/_fd_derivs.py @@ -4,8 +4,9 @@ import jax.numpy as jnp from jax import jit +from .utils import errorif + -@partial(jit, static_argnames=("method", "axis")) def approx_df( x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs ): @@ -21,7 +22,8 @@ def approx_df( method of approximation - ``'cubic'``: C1 cubic splines (aka local splines) - - ``'cubic2'``: C2 cubic splines (aka natural splines) + - ``'cubic2'``: C2 cubic splines. Can also pass kwarg ``bc_type``, same as + ``scipy.interpolate.CubicSpline`` - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass keyword parameter ``c`` in float[0,1] to specify tension @@ -29,6 +31,7 @@ def approx_df( data, and will not introduce new extrema in the interpolated points - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural axis : int Axis along which f is varying. @@ -39,18 +42,25 @@ def approx_df( First derivative of f with respect to x. """ + return _approx_df(x, f, method, axis, **kwargs) + + +@partial(jit, static_argnames=("method", "axis", "bc_type")) +def _approx_df(x, f, method, axis, c=0, bc_type="not-a-knot"): if method == "cubic": - out = _cubic1(x, f, axis, **kwargs) + out = _cubic1(x, f, axis) elif method == "cubic2": - out = _cubic2(x, f, axis) + out = _cubic2(x, f, axis, bc_type=bc_type) elif method == "cardinal": - out = _cardinal(x, f, axis, **kwargs) + out = _cardinal(x, f, axis, c=c) elif method == "catmull-rom": - out = _cardinal(x, f, axis, **kwargs) + out = _cardinal(x, f, axis, c=0) elif method == "monotonic": - out = _monotonic(x, f, axis, False, **kwargs) + out = _monotonic(x, f, axis, False) elif method == "monotonic-0": - out = _monotonic(x, f, axis, True, **kwargs) + out = _monotonic(x, f, axis, True) + elif method == "akima": + out = _akima(x, f, axis) elif method in ("nearest", "linear"): out = jnp.zeros_like(f) else: @@ -82,41 +92,156 @@ def _cubic1(x, f, axis): return fx -def _cubic2(x, f, axis): +def _validate_bc(bc_type, expected_deriv_shape): + if isinstance(bc_type, str): + errorif(bc_type == "periodic", NotImplementedError) + bc_type = (bc_type, bc_type) + + else: + errorif( + len(bc_type) != 2, + ValueError, + "`bc_type` must contain 2 elements to specify start and end conditions.", + ) + + errorif( + "periodic" in bc_type, + ValueError, + "'periodic' `bc_type` is defined for both " + + "curve ends and cannot be used with other " + + "boundary conditions.", + ) + + validated_bc = [] + for bc in bc_type: + if isinstance(bc, str): + errorif(bc_type == "periodic", NotImplementedError) + if bc == "clamped": + validated_bc.append((1, jnp.zeros(expected_deriv_shape))) + elif bc == "natural": + validated_bc.append((2, jnp.zeros(expected_deriv_shape))) + elif bc in ["not-a-knot", "periodic"]: + validated_bc.append(bc) + else: + raise ValueError(f"bc_type={bc} is not allowed.") + else: + try: + deriv_order, deriv_value = bc + except Exception as e: + raise ValueError( + "A specified derivative value must be " + "given in the form (order, value)." + ) from e + + if deriv_order not in [1, 2]: + raise ValueError("The specified derivative order must " "be 1 or 2.") + + deriv_value = jnp.asarray(deriv_value) + if deriv_value.shape != expected_deriv_shape: + raise ValueError( + "`deriv_value` shape {} is not the expected one {}.".format( + deriv_value.shape, expected_deriv_shape + ) + ) + validated_bc.append((deriv_order, deriv_value)) + return validated_bc + + +def _cubic2(x, f, axis, bc_type): + f = jnp.moveaxis(f, axis, 0) + bc = _validate_bc(bc_type, f.shape[1:]) dx = jnp.diff(x) - df = jnp.diff(f, axis=axis) - if df.ndim > dx.ndim: - dx = jnp.expand_dims(dx, tuple(range(1, df.ndim))) - dx = jnp.moveaxis(dx, 0, axis) - dxi = jnp.where(dx == 0, 0, 1 / dx) + df = jnp.diff(f, axis=0) + dxr = dx.reshape([dx.shape[0]] + [1] * (f.ndim - 1)) + dxi = jnp.where(dxr == 0, 0, 1 / jnp.where(dxr == 0, 1, dxr)) df = dxi * df + n = len(f) - one = jnp.array([1.0]) - dxflat = dx.flatten() - diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one]) - upper_diag = jnp.concatenate([one, dxflat[:-1]]) - lower_diag = jnp.concatenate([dxflat[1:], one]) + # If bc is 'not-a-knot' this change is just a convention. + # If bc is 'periodic' then we already checked that y[0] == y[-1], + # and the spline is just a constant, we handle this case in the + # same way by setting the first derivatives to slope, which is 0. + if n == 2: + if bc[0] in ["not-a-knot", "periodic"]: + bc[0] = (1, df[0]) + if bc[1] in ["not-a-knot", "periodic"]: + bc[1] = (1, df[0]) - A = jnp.diag(diag) + jnp.diag(upper_diag, k=1) + jnp.diag(lower_diag, k=-1) - b = jnp.concatenate( - [ - 2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"), - 3 - * ( - jnp.take(dx, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - * jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - + jnp.take(dx, jnp.arange(1, df.shape[axis]), axis, mode="wrap") - * jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap") - ), - 2 * jnp.take(df, jnp.array([-1]), axis, mode="wrap"), - ], - axis=axis, - ) - ba = jnp.moveaxis(b, axis, 0) - br = ba.reshape((b.shape[axis], -1)) - solve = lambda b: jnp.linalg.solve(A, b) - fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T - fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis) + # This is a special case, when both conditions are 'not-a-knot' + # and n == 3. In this case 'not-a-knot' can't be handled regularly + # as the both conditions are identical. We handle this case by + # constructing a parabola passing through given points. + if n == 3 and bc[0] == "not-a-knot" and bc[1] == "not-a-knot": + A = jnp.zeros((3, 3)) # This is a standard matrix. + b = jnp.empty((3,) + f.shape[1:], dtype=f.dtype) + + A = A.at[0, 0].set(1) + A = A.at[0, 1].set(1) + A = A.at[1, 0].set(dx[1]) + A = A.at[1, 1].set(2 * (dx[0] + dx[1])) + A = A.at[1, 2].set(dx[0]) + A = A.at[2, 1].set(1) + A = A.at[2, 2].set(1) + + b = b.at[0].set(2 * df[0]) + b = b.at[1].set(3 * (dxr[0] * df[1] + dxr[1] * df[0])) + b = b.at[2].set(2 * df[1]) + + s = jnp.linalg.solve(A, b) + fx = jnp.moveaxis(s, 0, axis) + + else: + + # Find derivative values at each x[i] by solving a tridiagonal + # system. + diag = jnp.zeros(n) + diag = diag.at[1:-1].set(2 * (dx[:-1] + dx[1:])) + upper_diag = jnp.zeros(n - 1) + upper_diag = upper_diag.at[1:].set(dx[:-1]) + lower_diag = jnp.zeros(n - 1) + lower_diag = lower_diag.at[:-1].set(dx[1:]) + b = jnp.zeros((n,) + f.shape[1:], dtype=f.dtype) + b = b.at[1:-1].set(3 * (dxr[1:] * df[:-1] + dxr[:-1] * df[1:])) + + bc_start, bc_end = bc + + if bc_start == "not-a-knot": + d = x[2] - x[0] + diag = diag.at[0].set(dx[1]) + upper_diag = upper_diag.at[0].set(d) + b = b.at[0].set( + ((dxr[0] + 2 * d) * dxr[1] * df[0] + dxr[0] ** 2 * df[1]) / d + ) + elif bc_start[0] == 1: + diag = diag.at[0].set(1) + upper_diag = upper_diag.at[0].set(0) + b = b.at[0].set(bc_start[1]) + elif bc_start[0] == 2: + diag = diag.at[0].set(2 * dx[0]) + upper_diag = upper_diag.at[0].set(dx[0]) + b = b.at[0].set(-0.5 * bc_start[1] * dx[0] ** 2 + 3 * (f[1] - f[0])) + + if bc_end == "not-a-knot": + d = x[-1] - x[-3] + diag = diag.at[-1].set(dx[-2]) + lower_diag = lower_diag.at[-1].set(d) + b = b.at[-1].set( + (dxr[-1] ** 2 * df[-2] + (2 * d + dxr[-1]) * dxr[-2] * df[-1]) / d + ) + elif bc_end[0] == 1: + diag = diag.at[-1].set(1) + lower_diag = lower_diag.at[-1].set(0) + b = b.at[-1].set(bc_end[1]) + elif bc_end[0] == 2: + diag = diag.at[-1].set(2 * dx[-1]) + lower_diag = lower_diag.at[-1].set(dx[-1]) + b = b.at[-1].set(0.5 * bc_end[1] * dx[-1] ** 2 + 3 * (f[-1] - f[-2])) + + A = jnp.diag(diag) + jnp.diag(upper_diag, k=1) + jnp.diag(lower_diag, k=-1) + + solve = lambda b: jnp.linalg.solve(A, b) + fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T + fx = jnp.moveaxis(fx, 0, axis) return fx @@ -151,11 +276,10 @@ def _monotonic(x, f, axis, zero_slope): x = x[:, None] f = f[:, None] hk = x[1:] - x[:-1] - df = jnp.diff(f, axis=axis) + df = jnp.diff(f, axis=0) hki = jnp.where(hk == 0, 0, 1 / hk) if df.ndim > hki.ndim: hki = jnp.expand_dims(hki, tuple(range(1, df.ndim))) - hki = jnp.moveaxis(hki, 0, axis) mk = hki * df @@ -167,9 +291,7 @@ def _monotonic(x, f, axis, zero_slope): if df.ndim > w1.ndim: w1 = jnp.expand_dims(w1, tuple(range(1, df.ndim))) - w1 = jnp.moveaxis(w1, 0, axis) w2 = jnp.expand_dims(w2, tuple(range(1, df.ndim))) - w2 = jnp.moveaxis(w2, 0, axis) whmean = (w1 / mk[:-1, :] + w2 / mk[1:, :]) / (w1 + w2) @@ -201,4 +323,40 @@ def _edge_case(h0, h1, m0, m1): dk = jnp.concatenate([d0, dk, d1]) dk = dk.reshape(fshp) - return dk.reshape(fshp) + return jnp.moveaxis(dk, 0, axis) + + +def _akima(x, f, axis): + # Original implementation in MATLAB by N. Shamsundar (BSD licensed), see + # https://www.mathworks.com/matlabcentral/fileexchange/1814-akima-interpolation + dx = jnp.diff(x) + f = jnp.moveaxis(f, axis, 0) + # determine slopes between breakpoints + m = jnp.empty((x.size + 3,) + f.shape[1:]) + dx = dx[(slice(None),) + (None,) * (f.ndim - 1)] + mask = dx == 0 + dx = jnp.where(mask, 1, dx) + dxi = jnp.where(mask, 0.0, 1 / dx) + m = m.at[2:-2].set(jnp.diff(f, axis=0) * dxi) + + # add two additional points on the left ... + m = m.at[1].set(2.0 * m[2] - m[3]) + m = m.at[0].set(2.0 * m[1] - m[2]) + # ... and on the right + m = m.at[-2].set(2.0 * m[-3] - m[-4]) + m = m.at[-1].set(2.0 * m[-2] - m[-3]) + + # df = derivative of f at x + # df = (|m4 - m3| * m2 + |m2 - m1| * m3) / (|m4 - m3| + |m2 - m1|) + # if m1 == m2 != m3 == m4, the slope at the breakpoint is not + # defined. Use instead 1/2(m2 + m3) + dm = jnp.abs(jnp.diff(m, axis=0)) + m2 = m[1:-2] + m3 = m[2:-1] + m4m3 = dm[2:] # |m4 - m3| + m2m1 = dm[:-2] # |m2 - m1| + f12 = m4m3 + m2m1 + mask = f12 > 1e-9 * jnp.max(f12, initial=-jnp.inf) + df = (m4m3 * m2 + m2m1 * m3) / jnp.where(mask, f12, 1.0) + df = jnp.where(mask, df, 0.5 * (m[3:] + m[:-3])) + return jnp.moveaxis(df, 0, axis) diff --git a/interpax/_spline.py b/interpax/_spline.py index 6fe10f4..4fbb6a3 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -14,9 +14,17 @@ from ._fd_derivs import approx_df from .utils import errorif, isbool -CUBIC_METHODS = ("cubic", "cubic2", "cardinal", "catmull-rom") +CUBIC_METHODS = ( + "cubic", + "cubic2", + "cardinal", + "catmull-rom", + "akima", + "monotonic", + "monotonic-0", +) OTHER_METHODS = ("nearest", "linear") -METHODS_1D = CUBIC_METHODS + OTHER_METHODS + ("monotonic", "monotonic-0") +METHODS_1D = CUBIC_METHODS + OTHER_METHODS METHODS_2D = CUBIC_METHODS + OTHER_METHODS METHODS_3D = CUBIC_METHODS + OTHER_METHODS @@ -44,6 +52,7 @@ class Interpolator1D(eqx.Module): data, and will not introduce new extrema in the interpolated points - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural extrap : bool, float, array-like whether to extrapolate values beyond knots (True) or return nan (False), @@ -149,6 +158,11 @@ class Interpolator2D(eqx.Module): - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass keyword parameter ``c`` in float[0,1] to specify tension + - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the + data, and will not introduce new extrema in the interpolated points + - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at + both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural extrap : bool, float, array-like whether to extrapolate values beyond knots (True) or return nan (False), @@ -273,6 +287,11 @@ class Interpolator3D(eqx.Module): - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass keyword parameter ``c`` in float[0,1] to specify tension + - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the + data, and will not introduce new extrema in the interpolated points + - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at + both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural extrap : bool, float, array-like whether to extrapolate values beyond knots (True) or return nan (False), @@ -448,6 +467,7 @@ def interp1d( data, and will not introduce new extrema in the interpolated points - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural derivative : int >= 0 derivative order to calculate @@ -532,7 +552,7 @@ def derivative2(): fq = jax.lax.switch(derivative, [derivative0, derivative1, derivative2]) - elif method in (CUBIC_METHODS + ("monotonic", "monotonic-0")): + elif method in CUBIC_METHODS: i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1) if fx is None: @@ -595,6 +615,11 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass keyword parameter ``c`` in float[0,1] to specify tension + - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the + data, and will not introduce new extrema in the interpolated points + - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at + both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural derivative : int >= 0 or array-like, shape(2,) derivative order to calculate in x, y. Use a single value for the same in both @@ -805,6 +830,11 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces - ``'catmull-rom'``: C1 cubic centripetal "tension" splines - ``'cardinal'``: C1 cubic general tension splines. If used, can also pass keyword parameter ``c`` in float[0,1] to specify tension + - ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the + data, and will not introduce new extrema in the interpolated points + - ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at + both endpoints + - ``'akima'``: C1 cubic splines that appear smooth and natural derivative : int >= 0, array-like, shape(3,) derivative order to calculate in x,y,z directions. Use a single value for the diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index 5e2ca53..f3c1441 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -65,6 +65,9 @@ def test_interp1d(self, x): fq = interp(x, xp, fp, method="monotonic-0") np.testing.assert_allclose(fq, f(x), rtol=1e-4, atol=1e-2) + fq = interp(x, xp, fp, method="akima") + np.testing.assert_allclose(fq, f(x), rtol=1e-6, atol=2e-5) + @pytest.mark.unit def test_interp1d_vector_valued(self): """Test for interpolating vector valued function.""" @@ -187,6 +190,12 @@ def test_interp2d(self, x, y): x, y, xp, yp, fp, method="cardinal", period=(2 * np.pi, 2 * np.pi) ) np.testing.assert_allclose(fq, f(x, y), rtol=rtol, atol=atol) + fq = interp(x, y, xp, yp, fp, method="akima", period=(2 * np.pi, 2 * np.pi)) + np.testing.assert_allclose(fq, f(x, y), rtol=rtol, atol=atol) + fq = interp( + x, y, xp, yp, fp, method="monotonic", period=(2 * np.pi, 2 * np.pi) + ) + np.testing.assert_allclose(fq, f(x, y), rtol=1e-2, atol=1e-3) @pytest.mark.unit def test_interp2d_vector_valued(self):