Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scipy.interpolate API #25

Merged
merged 5 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
Changelog
=========

- Adds a number of classes that replicate most of the functionality of the
corresponding classes from scipy.interpolate :
- ``scipy.interpolate.PPoly`` -> ``interpax.PPoly``
- ``scipy.interpolate.Akima1DInterpolator`` -> ``interpax.Akima1DInterpolator``
- ``scipy.interpolate.CubicHermiteSpline`` -> ``interpax.CubicHermiteSpline``
- ``scipy.interpolate.CubicSpline`` -> ``interpax.CubicSpline``
- ``scipy.interpolate.PchipInterpolator`` -> ``interpax.PchipInterpolator``
- Method ``"akima"`` now available for ``Interpolator.{1D, 2D, 3D}`` and corresponding
functions.
- Method ``"monotonic"`` now works in 2D and 3D, where it will preserve monotonicity
with respect to each coordinate individually.


v0.2.4
------
- Fixes for scalar valued query points
Expand Down
5 changes: 5 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ BUILDDIR = _build
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)


clean:
rm -rf _api/
rm -rf _build/

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
Expand Down
5 changes: 3 additions & 2 deletions docs/_templates/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
.. autosummary::
:toctree: {{ objname }}

{% for item in methods %}
{% if item != "__init__" %}

{% for item in all_methods %}
{%- if not item.startswith('_') or item in ['__call__',] %}
~{{ name }}.{{ item }}
{% endif %}
{%- endfor %}
Expand Down
81 changes: 54 additions & 27 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,65 @@
API Documentation
=================

interp1d
********
.. autofunction:: interpax.interp1d
Interpolation of 1D, 2D, or 3D data
-----------------------------------

interp2d
********
.. autofunction:: interpax.interp2d
.. autosummary::
:toctree: _api/
:recursive:
:template: class.rst

interp3d
********
.. autofunction:: interpax.interp3d
interpax.Interpolator1D
interpax.Interpolator2D
interpax.Interpolator3D

fft_interp1d
************
.. autofunction:: interpax.fft_interp1d

fft_interp2d
************
.. autofunction:: interpax.fft_interp2d
``scipy.interpolate``-like classes
----------------------------------

approx_df
*********
.. autofunction:: interpax.approx_df
These classes implement most of the functionality of the SciPy classes with the same names,
except where noted in the documentation.

Interpolator1D
**************
.. autoclass:: interpax.Interpolator1D
.. autosummary::
:toctree: _api/
:recursive:
:template: class.rst

Interpolator2D
**************
.. autoclass:: interpax.Interpolator2D
interpax.Akima1DInterpolator
interpax.CubicHermiteSpline
interpax.CubicSpline
interpax.PchipInterpolator
interpax.PPoly

Interpolator3D
**************
.. autoclass:: interpax.Interpolator3D

Functional interface for 1D, 2D, 3D interpolation
-------------------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

interpax.interp1d
interpax.interp2d
interpax.interp2d


Fourier interpolation of periodic functions in 1D and 2D
--------------------------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

interpax.fft_interp1d
interpax.fft_interp2d


Approximating first derivatives for cubic splines
-------------------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

interpax.approx_df
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def linkcode_resolve(domain, info):

autodoc_default_options = {
"member-order": "bysource",
"special-members": "__call__",
"exclude-members": "__init__",
}
# Add any paths that contain templates here, relative to this directory.
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


.. toctree::
:maxdepth: 2
:maxdepth: 3
:caption: Public API

api
Expand Down
7 changes: 7 additions & 0 deletions interpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from . import _version
from ._fd_derivs import approx_df
from ._fourier import fft_interp1d, fft_interp2d
from ._ppoly import (
Akima1DInterpolator,
CubicHermiteSpline,
CubicSpline,
PchipInterpolator,
PPoly,
)
from ._spline import (
Interpolator1D,
Interpolator2D,
Expand Down
35 changes: 19 additions & 16 deletions interpax/_fd_derivs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

from .utils import errorif
from .utils import asarray_inexact, errorif


def approx_df(
Expand Down Expand Up @@ -42,10 +40,13 @@ def approx_df(
First derivative of f with respect to x.

"""
return _approx_df(x, f, method, axis, **kwargs)
# close over static args to deal with non-jittable kwargs
def fun(x, f):
return _approx_df(x, f, method, axis, **kwargs)

return jit(fun)(x, f)


@partial(jit, static_argnames=("method", "axis", "bc_type"))
def _approx_df(x, f, method, axis, c=0, bc_type="not-a-knot"):
if method == "cubic":
out = _cubic1(x, f, axis)
Expand Down Expand Up @@ -92,7 +93,7 @@ def _cubic1(x, f, axis):
return fx


def _validate_bc(bc_type, expected_deriv_shape):
def _validate_bc(bc_type, expected_deriv_shape, dtype):
if isinstance(bc_type, str):
errorif(bc_type == "periodic", NotImplementedError)
bc_type = (bc_type, bc_type)
Expand Down Expand Up @@ -136,20 +137,21 @@ def _validate_bc(bc_type, expected_deriv_shape):
if deriv_order not in [1, 2]:
raise ValueError("The specified derivative order must " "be 1 or 2.")

deriv_value = jnp.asarray(deriv_value)
deriv_value = asarray_inexact(deriv_value)
dtype = jnp.promote_types(dtype, deriv_value.dtype)
if deriv_value.shape != expected_deriv_shape:
raise ValueError(
"`deriv_value` shape {} is not the expected one {}.".format(
deriv_value.shape, expected_deriv_shape
)
)
validated_bc.append((deriv_order, deriv_value))
return validated_bc
return validated_bc, dtype


def _cubic2(x, f, axis, bc_type):
f = jnp.moveaxis(f, axis, 0)
bc = _validate_bc(bc_type, f.shape[1:])
bc, dtype = _validate_bc(bc_type, f.shape[1:], f.dtype)
dx = jnp.diff(x)
df = jnp.diff(f, axis=0)
dxr = dx.reshape([dx.shape[0]] + [1] * (f.ndim - 1))
Expand All @@ -173,7 +175,7 @@ def _cubic2(x, f, axis, bc_type):
# constructing a parabola passing through given points.
if n == 3 and bc[0] == "not-a-knot" and bc[1] == "not-a-knot":
A = jnp.zeros((3, 3)) # This is a standard matrix.
b = jnp.empty((3,) + f.shape[1:], dtype=f.dtype)
b = jnp.empty((3,) + f.shape[1:], dtype=dtype)

A = A.at[0, 0].set(1)
A = A.at[0, 1].set(1)
Expand All @@ -187,20 +189,21 @@ def _cubic2(x, f, axis, bc_type):
b = b.at[1].set(3 * (dxr[0] * df[1] + dxr[1] * df[0]))
b = b.at[2].set(2 * df[1])

s = jnp.linalg.solve(A, b)
fx = jnp.moveaxis(s, 0, axis)
solve = lambda b: jnp.linalg.solve(A, b)
fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
fx = jnp.moveaxis(fx, 0, axis)

else:

# Find derivative values at each x[i] by solving a tridiagonal
# system.
diag = jnp.zeros(n)
diag = jnp.zeros(n, dtype=x.dtype)
diag = diag.at[1:-1].set(2 * (dx[:-1] + dx[1:]))
upper_diag = jnp.zeros(n - 1)
upper_diag = jnp.zeros(n - 1, dtype=x.dtype)
upper_diag = upper_diag.at[1:].set(dx[:-1])
lower_diag = jnp.zeros(n - 1)
lower_diag = jnp.zeros(n - 1, dtype=x.dtype)
lower_diag = lower_diag.at[:-1].set(dx[1:])
b = jnp.zeros((n,) + f.shape[1:], dtype=f.dtype)
b = jnp.zeros((n,) + f.shape[1:], dtype=dtype)
b = b.at[1:-1].set(3 * (dxr[1:] * df[:-1] + dxr[:-1] * df[1:]))

bc_start, bc_end = bc
Expand Down
Loading
Loading