diff --git a/interpax/__init__.py b/interpax/__init__.py index b650dbf..e006dcb 100644 --- a/interpax/__init__.py +++ b/interpax/__init__.py @@ -9,6 +9,7 @@ CubicSpline, PchipInterpolator, PPoly, + SmoothingSpline, ) from ._spline import ( Interpolator1D, diff --git a/interpax/_ppoly.py b/interpax/_ppoly.py index 0654053..6586815 100644 --- a/interpax/_ppoly.py +++ b/interpax/_ppoly.py @@ -803,3 +803,130 @@ def __init__( x, _, y, axis, _ = prepare_input(x, y, axis, check=check) df = approx_df(x, y, "cubic2", axis, bc_type=bc_type) super().__init__(x, y, df, axis=axis, extrapolate=extrapolate, check=check) + + +class SmoothingSpline(CubicSpline): + """Smoothing spline for noisy data. + + The spline f minimizes + + p ∑ᵢ wᵢ ||yᵢ − f(xᵢ)||² + (1−p) ∫ₓ ||f''(x)||² + + Parameters + ---------- + x : array_like, shape (n,) + 1-D array containing values of the independent variable. + Values must be real, finite and in strictly increasing order. + y : array_like, shape(n,...) + Array containing values of the dependent variable. It can have + arbitrary number of dimensions, but the length along ``axis`` + (see below) must match the length of ``x``. Values must be finite. + p : float, optional + Smoothing parameter in the range [0,1]. + + - For ``p=0`` the spline is a straight line fit to the data. + - For ``p=1``, it is the cubic spline interpolant. + + If not given, ``p`` is determined automatically given the data sites. The + calculation of the smoothing spline requires the solution of a linear system + whose coefficient matrix has the form pA + (1-p)B, with the matrices A and B + depending on the data sites. The automatically computed smoothing parameter + makes p*trace(A) equal (1 - p)*trace(B). + w : array_like, shape(n,) + Weights for spline fitting. Must be positive. If None, then weights are all + equal. Default is None. + axis : int, optional + Axis along which `y` is assumed to be varying. Meaning that for + ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``. + Default is 0. + bc_type : string or 2-tuple, optional + Boundary condition type. Two additional equations, given by the + boundary conditions, are required to determine all coefficients of + polynomials on each segment [2]_. + + If `bc_type` is a string, then the specified condition will be applied + at both ends of a spline. Available conditions are: + + * 'not-a-knot' (default): The first and second segment at a curve end + are the same polynomial. It is a good default when there is no + information on boundary conditions. + * 'periodic': The interpolated functions is assumed to be periodic + of period ``x[-1] - x[0]``. The first and last value of `y` must be + identical: ``y[0] == y[-1]``. This boundary condition will result in + ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``. + * 'clamped': The first derivative at curves ends are zero. Assuming + a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition. + * 'natural': The second derivative at curve ends are zero. Assuming + a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition. + + If `bc_type` is a 2-tuple, the first and the second value will be + applied at the curve start and end respectively. The tuple values can + be one of the previously mentioned strings (except 'periodic') or a + tuple `(order, deriv_values)` allowing to specify arbitrary + derivatives at curve ends: + + * `order`: the derivative order, 1 or 2. + * `deriv_value`: array_like containing derivative values, shape must + be the same as `y`, excluding ``axis`` dimension. For example, if + `y` is 1-D, then `deriv_value` must be a scalar. If `y` is 3-D with + the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2-D + and have the shape (n0, n1). + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. If 'periodic', + periodic extrapolation is used. If None (default), ``extrapolate`` is + set to 'periodic' for ``bc_type='periodic'`` and to True otherwise. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + """ + + def __init__( + self, + x: jax.Array, + y: jax.Array, + p: float = None, + w: jax.Array = None, + axis: int = 0, + bc_type: Union[str, tuple] = "natural", + extrapolate: Union[bool, str] = None, + check: bool = True, + ): + if w is None: + w = jnp.ones_like(x) + + dx = jnp.diff((x[1:] + x[:-1]) / 2, prepend=x[0], append=x[-1]) + + @jax.jit + def loss(f): + g = CubicSpline( + x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=False + ) + return jnp.concatenate([g(x), g(x, nu=2)]) + + AB = jax.jit(jax.jacfwd(loss))(jnp.zeros_like(x)) + A = AB[: AB.shape[1]] + B = AB[AB.shape[1] :] + A = jnp.sqrt(w)[:, None] * A + B = jnp.sqrt(dx)[:, None] * B + + # normalize smoothing parameter + span = jnp.ptp(x) + eff_x = 1 + (span**2) / jnp.sum(jnp.diff(x) ** 2) + eff_w = jnp.sum(w) ** 2 / jnp.sum(w**2) + k = 80 * (span**3) * (x.size**-2) * (eff_x**-0.5) * (eff_w**-0.5) + s = 0.5 if p is None else p + p = s / (s + (1 - s) * k) + + # p w ||Af - y||_2 + (1-p) ||Bf||_2 + # ||sqrt(p w)(Af - y)||_2 + ||sqrt(1-p)Bf||_2 + + lhs = jnp.vstack([jnp.sqrt(p) * A, jnp.sqrt(1 - p) * B]) + + y = jnp.moveaxis(y, axis, 0) + rhs = jnp.concatenate([jnp.sqrt(p * w) * y, jnp.zeros_like(y)], axis=0) + f = jnp.linalg.lstsq(lhs, rhs, rcond=None)[0] + f = jnp.moveaxis(f, 0, axis) + super().__init__( + x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=check + )