Skip to content

Commit

Permalink
Add typing info
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Nov 7, 2023
1 parent 689a4e5 commit 169a2ae
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 41 deletions.
15 changes: 12 additions & 3 deletions interpax/_fourier.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnames="n")
def fft_interp1d(f, n, sx=None, dx=1):
def fft_interp1d(f: jax.Array, n: int, sx: jax.Array = None, dx: float = 1.0):
"""Interpolation of a 1d periodic function via FFT.
Parameters
Expand Down Expand Up @@ -38,7 +39,15 @@ def fft_interp1d(f, n, sx=None, dx=1):


@partial(jit, static_argnames=("n1", "n2"))
def fft_interp2d(f, n1, n2, sx=None, sy=None, dx=1, dy=1):
def fft_interp2d(
f: jax.Array,
n1: int,
n2: int,
sx: jax.Array = None,
sy: jax.Array = None,
dx: float = 1.0,
dy: float = 1.0,
):
"""Interpolation of a 2d periodic function via FFT.
Parameters
Expand Down Expand Up @@ -82,7 +91,7 @@ def fft_interp2d(f, n1, n2, sx=None, sy=None, dx=1, dy=1):
return jnp.fft.fft2(c, axes=(0, 1)).real


def _pad_along_axis(array, pad=(0, 0), axis=0):
def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0):
"""Pad with zeros or truncate a given dimension."""
array = jnp.moveaxis(array, axis, 0)

Expand Down
127 changes: 89 additions & 38 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import OrderedDict
from functools import partial
from typing import Union

import equinox as eqx
import jax
Expand Down Expand Up @@ -62,11 +63,19 @@ class Interpolator1D(eqx.Module):
f: jax.Array
derivs: dict
method: str
extrap: bool | float | tuple
period: float | tuple
extrap: Union[bool, float, tuple]
period: Union[None, float]
axis: int

def __init__(self, x, f, method="cubic", extrap=False, period=None, **kwargs):
def __init__(
self,
x: jax.Array,
f: jax.Array,
method: str = "cubic",
extrap: Union[bool, float, tuple] = False,
period: Union[None, float] = None,
**kwargs,
):
x, f = map(jnp.asarray, (x, f))
axis = kwargs.get("axis", 0)
fx = kwargs.pop("fx", None)
Expand All @@ -90,7 +99,7 @@ def __init__(self, x, f, method="cubic", extrap=False, period=None, **kwargs):

self.derivs = {"fx": fx}

def __call__(self, xq, dx=0):
def __call__(self, xq: jax.Array, dx: int = 0):
"""Evaluate the interpolated function or its derivatives.
Parameters
Expand Down Expand Up @@ -161,11 +170,20 @@ class Interpolator2D(eqx.Module):
f: jax.Array
derivs: dict
method: str
extrap: bool | float | tuple
period: float | tuple
extrap: Union[bool, float, tuple]
period: Union[None, float, tuple]
axis: int

def __init__(self, x, y, f, method="cubic", extrap=False, period=None, **kwargs):
def __init__(
self,
x: jax.Array,
y: jax.Array,
f: jax.Array,
method: str = "cubic",
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
**kwargs,
):
x, y, f = map(jnp.asarray, (x, y, f))
axis = kwargs.get("axis", 0)
fx = kwargs.pop("fx", None)
Expand Down Expand Up @@ -201,7 +219,7 @@ def __init__(self, x, y, f, method="cubic", extrap=False, period=None, **kwargs)

self.derivs = {"fx": fx, "fy": fy, "fxy": fxy}

def __call__(self, xq, yq, dx=0, dy=0):
def __call__(self, xq: jax.Array, yq: jax.Array, dx: int = 0, dy: int = 0):
"""Evaluate the interpolated function or its derivatives.
Parameters
Expand Down Expand Up @@ -260,7 +278,7 @@ class Interpolator3D(eqx.Module):
also be passed as an array or tuple to specify different conditions
[[xlow, xhigh],[ylow,yhigh]]
period : float > 0, None, array-like, shape(2,)
periodicity of the function in x, y directions. None denotes no periodicity,
periodicity of the function in x, y, z directions. None denotes no periodicity,
otherwise function is assumed to be periodic on the interval [0,period]. Use a
single value for the same in both directions.
Expand All @@ -277,11 +295,21 @@ class Interpolator3D(eqx.Module):
f: jax.Array
derivs: dict
method: str
extrap: bool | float | tuple
period: float | tuple
extrap: Union[bool, float, tuple]
period: Union[None, float, tuple]
axis: int

def __init__(self, x, y, z, f, method="cubic", extrap=False, period=None, **kwargs):
def __init__(
self,
x: jax.Array,
y: jax.Array,
z: jax.Array,
f: jax.Array,
method: str = "cubic",
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
**kwargs,
):
x, y, z, f = map(jnp.asarray, (x, y, z, f))
axis = kwargs.get("axis", 0)

Expand Down Expand Up @@ -344,7 +372,15 @@ def __init__(self, x, y, z, f, method="cubic", extrap=False, period=None, **kwar
"fxyz": fxyz,
}

def __call__(self, xq, yq, zq, dx=0, dy=0, dz=0):
def __call__(
self,
xq: jax.Array,
yq: jax.Array,
zq: jax.Array,
dx: int = 0,
dy: int = 0,
dz: int = 0,
):
"""Evaluate the interpolated function or its derivatives.
Parameters
Expand Down Expand Up @@ -377,7 +413,14 @@ def __call__(self, xq, yq, zq, dx=0, dy=0, dz=0):

@partial(jit, static_argnames="method")
def interp1d(
xq, x, f, method="cubic", derivative=0, extrap=False, period=None, **kwargs
xq: jax.Array,
x: jax.Array,
f: jax.Array,
method: str = "cubic",
derivative: int = 0,
extrap: Union[bool, float, tuple] = False,
period: Union[None, float] = None,
**kwargs,
):
"""Interpolate a 1d function.
Expand Down Expand Up @@ -510,15 +553,15 @@ def derivative2():

@partial(jit, static_argnames="method")
def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
xq,
yq,
x,
y,
f,
method="cubic",
derivative=0,
extrap=False,
period=None,
xq: jax.Array,
yq: jax.Array,
x: jax.Array,
y: jax.Array,
f: jax.Array,
method: str = "cubic",
derivative: int = 0,
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
**kwargs,
):
"""Interpolate a 2d function.
Expand Down Expand Up @@ -708,17 +751,17 @@ def derivative1():

@partial(jit, static_argnames="method")
def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
xq,
yq,
zq,
x,
y,
z,
f,
method="cubic",
derivative=0,
extrap=False,
period=None,
xq: jax.Array,
yq: jax.Array,
zq: jax.Array,
x: jax.Array,
y: jax.Array,
z: jax.Array,
f: jax.Array,
method: str = "cubic",
derivative: int = 0,
extrap: Union[bool, float, tuple] = False,
period: Union[None, float, tuple] = None,
**kwargs,
):
"""Interpolate a 3d function.
Expand Down Expand Up @@ -994,7 +1037,7 @@ def derivative1():


@partial(jit, static_argnames=("axis"))
def _make_periodic(xq, x, period, axis, *arrs):
def _make_periodic(xq: jax.Array, x: jax.Array, period: float, axis: int, *arrs):
"""Make arrays periodic along a specified axis."""
period = abs(period)
xq = xq % period
Expand All @@ -1018,7 +1061,7 @@ def _make_periodic(xq, x, period, axis, *arrs):


@jit
def _get_t_der(t, derivative, dxi):
def _get_t_der(t: jax.Array, derivative: int, dxi: jax.Array):
"""Get arrays of [1,t,t^2,t^3] for cubic interpolation."""
t0 = jnp.zeros_like(t)
t1 = jnp.ones_like(t)
Expand Down Expand Up @@ -1058,7 +1101,13 @@ def _parse_extrap(extrap, n):


@jit
def _extrap(xq, fq, x, lo, hi):
def _extrap(
xq: jax.Array,
fq: jax.Array,
x: jax.Array,
lo: Union[bool, float],
hi: Union[bool, float],
):
"""Clamp or extrapolate values outside bounds."""

def loclip(fq, lo):
Expand Down Expand Up @@ -1095,7 +1144,9 @@ def noclip(fq, *_):


@partial(jit, static_argnames=("method", "axis"))
def approx_df(x, f, method="cubic", axis=-1, **kwargs):
def approx_df(
x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs
):
"""Approximates first derivatives using cubic spline interpolation.
Parameters
Expand Down

0 comments on commit 169a2ae

Please sign in to comment.