diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 19e75a7..a7f36a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 24.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/interpax/__init__.py b/interpax/__init__.py index b650dbf..e006dcb 100644 --- a/interpax/__init__.py +++ b/interpax/__init__.py @@ -9,6 +9,7 @@ CubicSpline, PchipInterpolator, PPoly, + SmoothingSpline, ) from ._spline import ( Interpolator1D, diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py index 896daeb..0f120e1 100644 --- a/interpax/_fd_derivs.py +++ b/interpax/_fd_derivs.py @@ -40,6 +40,8 @@ def approx_df( First derivative of f with respect to x. """ + # noqa:D202 + # close over static args to deal with non-jittable kwargs def fun(x, f): return _approx_df(x, f, method, axis, **kwargs) diff --git a/interpax/_ppoly.py b/interpax/_ppoly.py index 0654053..6586815 100644 --- a/interpax/_ppoly.py +++ b/interpax/_ppoly.py @@ -803,3 +803,130 @@ def __init__( x, _, y, axis, _ = prepare_input(x, y, axis, check=check) df = approx_df(x, y, "cubic2", axis, bc_type=bc_type) super().__init__(x, y, df, axis=axis, extrapolate=extrapolate, check=check) + + +class SmoothingSpline(CubicSpline): + """Smoothing spline for noisy data. + + The spline f minimizes + + p ∑ᵢ wᵢ ||yᵢ − f(xᵢ)||² + (1−p) ∫ₓ ||f''(x)||² + + Parameters + ---------- + x : array_like, shape (n,) + 1-D array containing values of the independent variable. + Values must be real, finite and in strictly increasing order. + y : array_like, shape(n,...) + Array containing values of the dependent variable. It can have + arbitrary number of dimensions, but the length along ``axis`` + (see below) must match the length of ``x``. Values must be finite. + p : float, optional + Smoothing parameter in the range [0,1]. + + - For ``p=0`` the spline is a straight line fit to the data. + - For ``p=1``, it is the cubic spline interpolant. + + If not given, ``p`` is determined automatically given the data sites. The + calculation of the smoothing spline requires the solution of a linear system + whose coefficient matrix has the form pA + (1-p)B, with the matrices A and B + depending on the data sites. The automatically computed smoothing parameter + makes p*trace(A) equal (1 - p)*trace(B). + w : array_like, shape(n,) + Weights for spline fitting. Must be positive. If None, then weights are all + equal. Default is None. + axis : int, optional + Axis along which `y` is assumed to be varying. Meaning that for + ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``. + Default is 0. + bc_type : string or 2-tuple, optional + Boundary condition type. Two additional equations, given by the + boundary conditions, are required to determine all coefficients of + polynomials on each segment [2]_. + + If `bc_type` is a string, then the specified condition will be applied + at both ends of a spline. Available conditions are: + + * 'not-a-knot' (default): The first and second segment at a curve end + are the same polynomial. It is a good default when there is no + information on boundary conditions. + * 'periodic': The interpolated functions is assumed to be periodic + of period ``x[-1] - x[0]``. The first and last value of `y` must be + identical: ``y[0] == y[-1]``. This boundary condition will result in + ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``. + * 'clamped': The first derivative at curves ends are zero. Assuming + a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition. + * 'natural': The second derivative at curve ends are zero. Assuming + a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition. + + If `bc_type` is a 2-tuple, the first and the second value will be + applied at the curve start and end respectively. The tuple values can + be one of the previously mentioned strings (except 'periodic') or a + tuple `(order, deriv_values)` allowing to specify arbitrary + derivatives at curve ends: + + * `order`: the derivative order, 1 or 2. + * `deriv_value`: array_like containing derivative values, shape must + be the same as `y`, excluding ``axis`` dimension. For example, if + `y` is 1-D, then `deriv_value` must be a scalar. If `y` is 3-D with + the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2-D + and have the shape (n0, n1). + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. If 'periodic', + periodic extrapolation is used. If None (default), ``extrapolate`` is + set to 'periodic' for ``bc_type='periodic'`` and to True otherwise. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + """ + + def __init__( + self, + x: jax.Array, + y: jax.Array, + p: float = None, + w: jax.Array = None, + axis: int = 0, + bc_type: Union[str, tuple] = "natural", + extrapolate: Union[bool, str] = None, + check: bool = True, + ): + if w is None: + w = jnp.ones_like(x) + + dx = jnp.diff((x[1:] + x[:-1]) / 2, prepend=x[0], append=x[-1]) + + @jax.jit + def loss(f): + g = CubicSpline( + x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=False + ) + return jnp.concatenate([g(x), g(x, nu=2)]) + + AB = jax.jit(jax.jacfwd(loss))(jnp.zeros_like(x)) + A = AB[: AB.shape[1]] + B = AB[AB.shape[1] :] + A = jnp.sqrt(w)[:, None] * A + B = jnp.sqrt(dx)[:, None] * B + + # normalize smoothing parameter + span = jnp.ptp(x) + eff_x = 1 + (span**2) / jnp.sum(jnp.diff(x) ** 2) + eff_w = jnp.sum(w) ** 2 / jnp.sum(w**2) + k = 80 * (span**3) * (x.size**-2) * (eff_x**-0.5) * (eff_w**-0.5) + s = 0.5 if p is None else p + p = s / (s + (1 - s) * k) + + # p w ||Af - y||_2 + (1-p) ||Bf||_2 + # ||sqrt(p w)(Af - y)||_2 + ||sqrt(1-p)Bf||_2 + + lhs = jnp.vstack([jnp.sqrt(p) * A, jnp.sqrt(1 - p) * B]) + + y = jnp.moveaxis(y, axis, 0) + rhs = jnp.concatenate([jnp.sqrt(p * w) * y, jnp.zeros_like(y)], axis=0) + f = jnp.linalg.lstsq(lhs, rhs, rcond=None)[0] + f = jnp.moveaxis(f, 0, axis) + super().__init__( + x, f, axis=axis, bc_type=bc_type, extrapolate=extrapolate, check=check + ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 99098ad..b4bac3b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,7 +12,7 @@ sphinx-github-style >= 1.0, <= 1.1 # linting -black == 22.10.0 +black == 24.3.0 flake8 >= 5.0.0, <=6.0.0 flake8-docstrings >= 1.0.0, <=2.0.0 flake8-eradicate >= 1.0.0, <=2.0.0 diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index f3c1441..f2760a6 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -286,9 +286,7 @@ def test_interp3d_vector_valued(self): zp = np.linspace(0, 3, 25) xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij") - f = lambda x, y, z: np.array( - [np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)] - ) + f = lambda x, y, z: np.array([np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)]) fp = f(xxp.T, yyp.T, zzp.T).T fq = interp3d(x, y, z, xp, yp, zp, fp, method="nearest")