Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add smoothing spline #30

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
1 change: 1 addition & 0 deletions interpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CubicSpline,
PchipInterpolator,
PPoly,
SmoothingSpline,
)
from ._spline import (
Interpolator1D,
Expand Down
2 changes: 2 additions & 0 deletions interpax/_fd_derivs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def approx_df(
First derivative of f with respect to x.

"""
# noqa:D202

# close over static args to deal with non-jittable kwargs
def fun(x, f):
return _approx_df(x, f, method, axis, **kwargs)
Expand Down
127 changes: 127 additions & 0 deletions interpax/_ppoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,130 @@ def __init__(
x, _, y, axis, _ = prepare_input(x, y, axis, check=check)
df = approx_df(x, y, "cubic2", axis, bc_type=bc_type)
super().__init__(x, y, df, axis=axis, extrapolate=extrapolate, check=check)


class SmoothingSpline(CubicSpline):
"""Smoothing spline for noisy data.

The spline f minimizes

p ∑ᵢ wᵢ ||yᵢ − f(xᵢ)||² + (1−p) ∫ₓ ||f''(x)||²

Parameters
----------
x : array_like, shape (n,)
1-D array containing values of the independent variable.
Values must be real, finite and in strictly increasing order.
y : array_like, shape(n,...)
Array containing values of the dependent variable. It can have
arbitrary number of dimensions, but the length along ``axis``
(see below) must match the length of ``x``. Values must be finite.
p : float, optional
Smoothing parameter in the range [0,1].

- For ``p=0`` the spline is a straight line fit to the data.
- For ``p=1``, it is the cubic spline interpolant.

If not given, ``p`` is determined automatically given the data sites. The
calculation of the smoothing spline requires the solution of a linear system
whose coefficient matrix has the form pA + (1-p)B, with the matrices A and B
depending on the data sites. The automatically computed smoothing parameter
makes p*trace(A) equal (1 - p)*trace(B).
w : array_like, shape(n,)
Weights for spline fitting. Must be positive. If None, then weights are all
equal. Default is None.
axis : int, optional
Axis along which `y` is assumed to be varying. Meaning that for
``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``.
Default is 0.
bc_type : string or 2-tuple, optional
Boundary condition type. Two additional equations, given by the
boundary conditions, are required to determine all coefficients of
polynomials on each segment [2]_.

If `bc_type` is a string, then the specified condition will be applied
at both ends of a spline. Available conditions are:

* 'not-a-knot' (default): The first and second segment at a curve end
are the same polynomial. It is a good default when there is no
information on boundary conditions.
* 'periodic': The interpolated functions is assumed to be periodic
of period ``x[-1] - x[0]``. The first and last value of `y` must be
identical: ``y[0] == y[-1]``. This boundary condition will result in
``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``.
* 'clamped': The first derivative at curves ends are zero. Assuming
a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition.
* 'natural': The second derivative at curve ends are zero. Assuming
a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition.

If `bc_type` is a 2-tuple, the first and the second value will be
applied at the curve start and end respectively. The tuple values can
be one of the previously mentioned strings (except 'periodic') or a
tuple `(order, deriv_values)` allowing to specify arbitrary
derivatives at curve ends:

* `order`: the derivative order, 1 or 2.
* `deriv_value`: array_like containing derivative values, shape must
be the same as `y`, excluding ``axis`` dimension. For example, if
`y` is 1-D, then `deriv_value` must be a scalar. If `y` is 3-D with
the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2-D
and have the shape (n0, n1).
extrapolate : {bool, 'periodic', None}, optional
If bool, determines whether to extrapolate to out-of-bounds points
based on first and last intervals, or to return NaNs. If 'periodic',
periodic extrapolation is used. If None (default), ``extrapolate`` is
set to 'periodic' for ``bc_type='periodic'`` and to True otherwise.
check : bool
Whether to perform checks on the input. Should be False if used under JIT.

"""

def __init__(
self,
x: jax.Array,
y: jax.Array,
p: float = None,
w: jax.Array = None,
axis: int = 0,
bc_type: Union[str, tuple] = "natural",
extrapolate: Union[bool, str] = None,
check: bool = True,
):
if w is None:
w = jnp.ones_like(x)

dx = jnp.diff((x[1:] + x[:-1]) / 2, prepend=x[0], append=x[-1])

@jax.jit
def loss(f):
g = CubicSpline(
x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=False
)
return jnp.concatenate([g(x), g(x, nu=2)])

AB = jax.jit(jax.jacfwd(loss))(jnp.zeros_like(x))
A = AB[: AB.shape[1]]
B = AB[AB.shape[1] :]
A = jnp.sqrt(w)[:, None] * A
B = jnp.sqrt(dx)[:, None] * B

# normalize smoothing parameter
span = jnp.ptp(x)
eff_x = 1 + (span**2) / jnp.sum(jnp.diff(x) ** 2)
eff_w = jnp.sum(w) ** 2 / jnp.sum(w**2)
k = 80 * (span**3) * (x.size**-2) * (eff_x**-0.5) * (eff_w**-0.5)
s = 0.5 if p is None else p
p = s / (s + (1 - s) * k)

# p w ||Af - y||_2 + (1-p) ||Bf||_2
# ||sqrt(p w)(Af - y)||_2 + ||sqrt(1-p)Bf||_2

lhs = jnp.vstack([jnp.sqrt(p) * A, jnp.sqrt(1 - p) * B])

y = jnp.moveaxis(y, axis, 0)
rhs = jnp.concatenate([jnp.sqrt(p * w) * y, jnp.zeros_like(y)], axis=0)
f = jnp.linalg.lstsq(lhs, rhs, rcond=None)[0]
f = jnp.moveaxis(f, 0, axis)
super().__init__(
x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=check
)
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ sphinx-github-style >= 1.0, <= 1.1


# linting
black == 22.10.0
black == 24.3.0
flake8 >= 5.0.0, <=6.0.0
flake8-docstrings >= 1.0.0, <=2.0.0
flake8-eradicate >= 1.0.0, <=2.0.0
Expand Down
4 changes: 1 addition & 3 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,7 @@ def test_interp3d_vector_valued(self):
zp = np.linspace(0, 3, 25)
xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij")

f = lambda x, y, z: np.array(
[np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)]
)
f = lambda x, y, z: np.array([np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)])
fp = f(xxp.T, yyp.T, zzp.T).T

fq = interp3d(x, y, z, xp, yp, zp, fp, method="nearest")
Expand Down
Loading