Skip to content

Commit

Permalink
Merge pull request #12 from allen-adastra/allenw/interp_scalars
Browse files Browse the repository at this point in the history
Allenw/interp scalars
  • Loading branch information
f0uriest authored Nov 18, 2023
2 parents c3b1d4c + 4bd3f51 commit 4160329
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
15 changes: 15 additions & 0 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ def interp1d(
fx = kwargs.pop("fx", None)
outshape = xq.shape + f.shape[1:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq = jnp.atleast_1d(xq)

errorif(
(len(x) != f.shape[axis]) or (jnp.ndim(x) != 1),
ValueError,
Expand Down Expand Up @@ -621,6 +626,11 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
xq, yq = jnp.broadcast_arrays(xq, yq)
outshape = xq.shape + f.shape[2:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq, yq = map(jnp.atleast_1d, (xq, yq))

errorif(
(len(x) != f.shape[0]) or (x.ndim != 1),
ValueError,
Expand Down Expand Up @@ -839,6 +849,11 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
xq, yq, zq = jnp.broadcast_arrays(xq, yq, zq)
outshape = xq.shape + f.shape[3:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq, yq, zq = map(jnp.atleast_1d, (xq, yq, zq))

fx = kwargs.pop("fx", None)
fy = kwargs.pop("fy", None)
fz = kwargs.pop("fz", None)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
equinox
jax[cpu] >= 0.3.2, <= 0.4.14
jax[cpu] >= 0.3.2, <= 0.4.20
numpy >= 1.20.0, < 1.25.0
scipy >= 1.5.0, < 1.11.0
39 changes: 29 additions & 10 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.config import config as jax_config
from jax import config as jax_config

from interpax import (
Interpolator1D,
Expand All @@ -24,10 +24,16 @@ class TestInterp1D:
"""Tests for interp1d function."""

@pytest.mark.unit
def test_interp1d(self):
@pytest.mark.parametrize(
"x",
[
np.linspace(0, 2 * np.pi, 10000),
0.0,
],
)
def test_interp1d(self, x):
"""Test accuracy of different 1d interpolation methods."""
xp = np.linspace(0, 2 * np.pi, 100)
x = np.linspace(0, 2 * np.pi, 10000)
f = lambda x: np.sin(x)
fp = f(xp)

Expand Down Expand Up @@ -99,12 +105,17 @@ class TestInterp2D:
"""Tests for interp2d function."""

@pytest.mark.unit
def test_interp2d(self):
@pytest.mark.parametrize(
"x, y",
[
(np.linspace(0, 3 * np.pi, 1000), np.linspace(0, 2 * np.pi, 1000)),
(0.0, 0.0),
],
)
def test_interp2d(self, x, y):
"""Test accuracy of different 2d interpolation methods."""
xp = np.linspace(0, 3 * np.pi, 99)
yp = np.linspace(0, 2 * np.pi, 40)
x = np.linspace(0, 3 * np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
xxp, yyp = np.meshgrid(xp, yp, indexing="ij")

f = lambda x, y: np.sin(x) * np.cos(y)
Expand Down Expand Up @@ -150,14 +161,22 @@ class TestInterp3D:
"""Tests for interp3d function."""

@pytest.mark.unit
def test_interp3d(self):
@pytest.mark.parametrize(
"x, y, z",
[
(
np.linspace(0, np.pi, 1000),
np.linspace(0, 2 * np.pi, 1000),
np.linspace(0, 3, 1000),
),
(0.0, 0.0, 0.0),
],
)
def test_interp3d(self, x, y, z):
"""Test accuracy of different 3d interpolation methods."""
xp = np.linspace(0, np.pi, 20)
yp = np.linspace(0, 2 * np.pi, 30)
zp = np.linspace(0, 3, 25)
x = np.linspace(0, np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
z = np.linspace(0, 3, 1000)
xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij")

f = lambda x, y, z: np.sin(x) * np.cos(y) * z**2
Expand Down

0 comments on commit 4160329

Please sign in to comment.