Skip to content

Commit

Permalink
Merge pull request #22 from f0uriest/rc/cleanup
Browse files Browse the repository at this point in the history
Cleanup
  • Loading branch information
f0uriest authored Mar 4, 2024
2 parents 1047afe + c2d0e7c commit 64b6b34
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 455 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/scheduled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion interpax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""interpax: interpolation and function approximation with JAX."""

from . import _version
from ._fd_derivs import approx_df
from ._fourier import fft_interp1d, fft_interp2d
from ._spline import (
Interpolator1D,
Interpolator2D,
Interpolator3D,
approx_df,
interp1d,
interp2d,
interp3d,
Expand Down
161 changes: 161 additions & 0 deletions interpax/_coefs.py

Large diffs are not rendered by default.

204 changes: 204 additions & 0 deletions interpax/_fd_derivs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnames=("method", "axis"))
def approx_df(
x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs
):
"""Approximates first derivatives using cubic spline interpolation.
Parameters
----------
x : ndarray, shape(Nx,)
coordinates of known function values ("knots")
f : ndarray
Known function values. Should have length ``Nx`` along axis=axis
method : str
method of approximation
- ``'cubic'``: C1 cubic splines (aka local splines)
- ``'cubic2'``: C2 cubic splines (aka natural splines)
- ``'catmull-rom'``: C1 cubic centripetal "tension" splines
- ``'cardinal'``: C1 cubic general tension splines. If used, can also pass
keyword parameter ``c`` in float[0,1] to specify tension
- ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the
data, and will not introduce new extrema in the interpolated points
- ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at
both endpoints
axis : int
Axis along which f is varying.
Returns
-------
df : ndarray, shape(f.shape)
First derivative of f with respect to x.
"""
if method == "cubic":
out = _cubic1(x, f, axis, **kwargs)
elif method == "cubic2":
out = _cubic2(x, f, axis)
elif method == "cardinal":
out = _cardinal(x, f, axis, **kwargs)
elif method == "catmull-rom":
out = _cardinal(x, f, axis, **kwargs)
elif method == "monotonic":
out = _monotonic(x, f, axis, False, **kwargs)
elif method == "monotonic-0":
out = _monotonic(x, f, axis, True, **kwargs)
elif method in ("nearest", "linear"):
out = jnp.zeros_like(f)
else:
raise ValueError(f"got unknown method {method}")
return out


def _cubic1(x, f, axis):
dx = jnp.diff(x)
df = jnp.diff(f, axis=axis)
dxi = jnp.where(dx == 0, 0, 1 / dx)
if df.ndim > dxi.ndim:
dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim)))
dxi = jnp.moveaxis(dxi, 0, axis)
df = dxi * df
fx = jnp.concatenate(
[
jnp.take(df, jnp.array([0]), axis, mode="wrap"),
1
/ 2
* (
jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap")
+ jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap")
),
jnp.take(df, jnp.array([-1]), axis, mode="wrap"),
],
axis=axis,
)
return fx


def _cubic2(x, f, axis):
dx = jnp.diff(x)
df = jnp.diff(f, axis=axis)
if df.ndim > dx.ndim:
dx = jnp.expand_dims(dx, tuple(range(1, df.ndim)))
dx = jnp.moveaxis(dx, 0, axis)
dxi = jnp.where(dx == 0, 0, 1 / dx)
df = dxi * df

one = jnp.array([1.0])
dxflat = dx.flatten()
diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one])
upper_diag = jnp.concatenate([one, dxflat[:-1]])
lower_diag = jnp.concatenate([dxflat[1:], one])

A = jnp.diag(diag) + jnp.diag(upper_diag, k=1) + jnp.diag(lower_diag, k=-1)
b = jnp.concatenate(
[
2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"),
3
* (
jnp.take(dx, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap")
* jnp.take(df, jnp.arange(1, df.shape[axis]), axis, mode="wrap")
+ jnp.take(dx, jnp.arange(1, df.shape[axis]), axis, mode="wrap")
* jnp.take(df, jnp.arange(0, df.shape[axis] - 1), axis, mode="wrap")
),
2 * jnp.take(df, jnp.array([-1]), axis, mode="wrap"),
],
axis=axis,
)
ba = jnp.moveaxis(b, axis, 0)
br = ba.reshape((b.shape[axis], -1))
solve = lambda b: jnp.linalg.solve(A, b)
fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T
fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis)
return fx


def _cardinal(x, f, axis, c=0):
dx = x[2:] - x[:-2]
df = jnp.take(f, jnp.arange(2, f.shape[axis]), axis, mode="wrap") - jnp.take(
f, jnp.arange(0, f.shape[axis] - 2), axis, mode="wrap"
)
dxi = jnp.where(dx == 0, 0, 1 / dx)
if df.ndim > dxi.ndim:
dxi = jnp.expand_dims(dxi, tuple(range(1, df.ndim)))
dxi = jnp.moveaxis(dxi, 0, axis)
df = dxi * df
fx0 = jnp.take(f, jnp.array([1]), axis, mode="wrap") - jnp.take(
f, jnp.array([0]), axis, mode="wrap"
)
fx0 *= jnp.where(x[0] == x[1], 0, 1 / (x[1] - x[0]))
fx1 = jnp.take(f, jnp.array([-1]), axis, mode="wrap") - jnp.take(
f, jnp.array([-2]), axis, mode="wrap"
)
fx1 *= jnp.where(x[-1] == x[-2], 0, 1 / (x[-1] - x[-2]))

fx = (1 - c) * jnp.concatenate([fx0, df, fx1], axis=axis)
return fx


def _monotonic(x, f, axis, zero_slope):
f = jnp.moveaxis(f, axis, 0)
fshp = f.shape
if f.ndim == 1:
# So that _edge_case doesn't end up assigning to scalars
x = x[:, None]
f = f[:, None]
hk = x[1:] - x[:-1]
df = jnp.diff(f, axis=axis)
hki = jnp.where(hk == 0, 0, 1 / hk)
if df.ndim > hki.ndim:
hki = jnp.expand_dims(hki, tuple(range(1, df.ndim)))
hki = jnp.moveaxis(hki, 0, axis)

mk = hki * df

smk = jnp.sign(mk)
condition = (smk[1:, :] != smk[:-1, :]) | (mk[1:, :] == 0) | (mk[:-1, :] == 0)

w1 = 2 * hk[1:] + hk[:-1]
w2 = hk[1:] + 2 * hk[:-1]

if df.ndim > w1.ndim:
w1 = jnp.expand_dims(w1, tuple(range(1, df.ndim)))
w1 = jnp.moveaxis(w1, 0, axis)
w2 = jnp.expand_dims(w2, tuple(range(1, df.ndim)))
w2 = jnp.moveaxis(w2, 0, axis)

whmean = (w1 / mk[:-1, :] + w2 / mk[1:, :]) / (w1 + w2)

dk = jnp.where(condition, 0, 1.0 / whmean)

if zero_slope:
d0 = jnp.zeros((1, dk.shape[1]))
d1 = jnp.zeros((1, dk.shape[1]))

else:
# special case endpoints, as suggested in
# Cleve Moler, Numerical Computing with MATLAB, Chap 3.6 (pchiptx.m)
def _edge_case(h0, h1, m0, m1):
# one-sided three-point estimate for the derivative
d = ((2 * h0 + h1) * m0 - h0 * m1) / (h0 + h1)

# try to preserve shape
mask = jnp.sign(d) != jnp.sign(m0)
mask2 = (jnp.sign(m0) != jnp.sign(m1)) & (jnp.abs(d) > 3.0 * jnp.abs(m0))
mmm = (~mask) & mask2

d = jnp.where(mask, 0.0, d)
d = jnp.where(mmm, 3.0 * m0, d)
return d

hk = 1 / hki
d0 = _edge_case(hk[0, :], hk[1, :], mk[0, :], mk[1, :])[None]
d1 = _edge_case(hk[-1, :], hk[-2, :], mk[-1, :], mk[-2, :])[None]

dk = jnp.concatenate([d0, dk, d1])
dk = dk.reshape(fshp)
return dk.reshape(fshp)
Loading

0 comments on commit 64b6b34

Please sign in to comment.