Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for interpolating vector valued functions #13

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
Changelog
=========

v0.2.4
------
- Fixes for scalar valued query points
- Fixes for interpolating vector valued functions

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.3...v0.2.4


v0.2.3
------
- Add type annotations

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.2...v0.2.3


v0.2.2
------
- Add ``approx_df`` to public API

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.1...v0.2.2


v0.2.1
------
- More efficient nearest neighbor search
- Correct slopes for linear interpolation in 2d, 3d
- Fix for cubic2 splines in 2d and 3d
Forward and reverse mode AD now fully working and tested

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.0...v0.2.1


v0.2.0
-------
- Adds convenience classes for spline interpolation that cache the derivative calculation.

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.1.0...v0.2.0


v0.1.0
------
Expand Down
76 changes: 38 additions & 38 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,17 +513,17 @@ def derivative0():
delta = xq - x[i - 1]
fq = jnp.where(
(dx == 0),
jnp.take(f, i, axis),
jnp.take(f, i - 1, axis) + delta * dxi * df,
)
jnp.take(f, i, axis).T,
jnp.take(f, i - 1, axis).T + (delta * dxi * df.T),
).T
return fq

def derivative1():
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis)
dx = x[i] - x[i - 1]
dxi = jnp.where(dx == 0, 0, 1 / dx)
return df * dxi
return (df.T * dxi).T

def derivative2():
return jnp.zeros((xq.size, *f.shape[1:]))
Expand All @@ -544,11 +544,11 @@ def derivative2():

f0 = jnp.take(f, i - 1, axis)
f1 = jnp.take(f, i, axis)
fx0 = jnp.take(fx, i - 1, axis) * dx
fx1 = jnp.take(fx, i, axis) * dx
fx0 = (jnp.take(fx, i - 1, axis).T * dx).T
fx1 = (jnp.take(fx, i, axis).T * dx).T

F = jnp.vstack([f0, f1, fx0, fx1])
coef = jnp.matmul(A_CUBIC, F)
F = jnp.stack([f0, f1, fx0, fx1], axis=0).T
coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_CUBIC, F).T
ttx = _get_t_der(t, derivative, dxi)
fq = jnp.einsum("ji...,ij->i...", coef, ttx)

Expand Down Expand Up @@ -666,12 +666,12 @@ def derivative0():
[[x[i], x[i - 1], x[i], x[i - 1]], [y[j], y[j], y[j - 1], y[j - 1]]]
)
neighbors_f = jnp.array(
[f[i, j], f[i - 1, j], f[i, j - 1], f[i - 1, j - 1]]
[f[i, j].T, f[i - 1, j].T, f[i, j - 1].T, f[i - 1, j - 1].T]
)
xyq = jnp.array([xq, yq])
dist = jnp.linalg.norm(neighbors_x - xyq[:, None, :], axis=0)
idx = jnp.argmin(dist, axis=0)
return jax.vmap(jnp.take)(neighbors_f.T, idx)
return jax.vmap(lambda a, b: jnp.take(a, b, axis=-1))(neighbors_f.T, idx)

def derivative1():
return jnp.zeros((xq.size, *f.shape[2:]))
Expand Down Expand Up @@ -708,7 +708,7 @@ def derivative1():
tx = jax.lax.switch(derivative_x, [dx0, dx1, dx2])
ty = jax.lax.switch(derivative_y, [dy0, dy1, dy2])
F = jnp.array([[f00, f01], [f10, f11]])
fq = dxi * dyi * jnp.einsum("ijk...,ik,jk->k...", F, tx, ty)
fq = (dxi * dyi * jnp.einsum("ijk...,ik,jk->k...", F, tx, ty).T).T

elif method in CUBIC_METHODS:
if fx is None:
Expand Down Expand Up @@ -740,15 +740,16 @@ def derivative1():
for ff in fs.keys():
for jj in [0, 1]:
for ii in [0, 1]:
fsq[ff + str(ii) + str(jj)] = fs[ff][i - 1 + ii, j - 1 + jj]
s = ff + str(ii) + str(jj)
fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj]
if "x" in ff:
fsq[ff + str(ii) + str(jj)] *= dx
fsq[s] = (dx * fsq[s].T).T
if "y" in ff:
fsq[ff + str(ii) + str(jj)] *= dy
fsq[s] = (dy * fsq[s].T).T

F = jnp.vstack([foo for foo in fsq.values()])
coef = jnp.matmul(A_BICUBIC, F)
coef = jnp.moveaxis(coef.reshape((4, 4, -1), order="F"), -1, 0)
F = jnp.stack([foo for foo in fsq.values()], axis=0).T
coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_BICUBIC, F).T
coef = jnp.moveaxis(coef.reshape((4, 4, *coef.shape[1:]), order="F"), 2, 0)
ttx = _get_t_der(tx, derivative_x, dxi)
tty = _get_t_der(ty, derivative_y, dyi)
fq = jnp.einsum("ijk...,ij,ik->i...", coef, ttx, tty)
Expand Down Expand Up @@ -900,20 +901,20 @@ def derivative0():
)
neighbors_f = jnp.array(
[
f[i, j, k],
f[i - 1, j, k],
f[i, j - 1, k],
f[i - 1, j - 1, k],
f[i, j, k - 1],
f[i - 1, j, k - 1],
f[i, j - 1, k - 1],
f[i - 1, j - 1, k - 1],
f[i, j, k].T,
f[i - 1, j, k].T,
f[i, j - 1, k].T,
f[i - 1, j - 1, k].T,
f[i, j, k - 1].T,
f[i - 1, j, k - 1].T,
f[i, j - 1, k - 1].T,
f[i - 1, j - 1, k - 1].T,
]
)
xyzq = jnp.array([xq, yq, zq])
dist = jnp.linalg.norm(neighbors_x - xyzq[:, None, :], axis=0)
idx = jnp.argmin(dist, axis=0)
return jax.vmap(jnp.take)(neighbors_f.T, idx)
return jax.vmap(lambda a, b: jnp.take(a, b, axis=-1))(neighbors_f.T, idx)

def derivative1():
return jnp.zeros((xq.size, *f.shape[3:]))
Expand Down Expand Up @@ -966,7 +967,7 @@ def derivative1():
tz = jax.lax.switch(derivative_z, [dz0, dz1, dz2])

F = jnp.array([[[f000, f001], [f010, f011]], [[f100, f101], [f110, f111]]])
fq = dxi * dyi * dzi * jnp.einsum("lijk...,lk,ik,jk->k...", F, tx, ty, tz)
fq = (dxi * dyi * dzi * jnp.einsum("lijk...,lk,ik,jk->k...", F, tx, ty, tz).T).T

elif method in CUBIC_METHODS:
if fx is None:
Expand Down Expand Up @@ -1026,19 +1027,18 @@ def derivative1():
for kk in [0, 1]:
for jj in [0, 1]:
for ii in [0, 1]:
fsq[ff + str(ii) + str(jj) + str(kk)] = fs[ff][
i - 1 + ii, j - 1 + jj, k - 1 + kk
]
s = ff + str(ii) + str(jj) + str(kk)
fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj, k - 1 + kk]
if "x" in ff:
fsq[ff + str(ii) + str(jj) + str(kk)] *= dx
fsq[s] = (dx * fsq[s].T).T
if "y" in ff:
fsq[ff + str(ii) + str(jj) + str(kk)] *= dy
fsq[s] = (dy * fsq[s].T).T
if "z" in ff:
fsq[ff + str(ii) + str(jj) + str(kk)] *= dz
fsq[s] = (dz * fsq[s].T).T

F = jnp.vstack([foo for foo in fsq.values()])
coef = jnp.matmul(A_TRICUBIC, F)
coef = jnp.moveaxis(coef.reshape((4, 4, 4, -1), order="F"), -1, 0)
F = jnp.stack([foo for foo in fsq.values()], axis=0).T
coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_TRICUBIC, F).T
coef = jnp.moveaxis(coef.reshape((4, 4, 4, *coef.shape[1:]), order="F"), 3, 0)
ttx = _get_t_der(tx, derivative_x, dxi)
tty = _get_t_der(ty, derivative_y, dyi)
ttz = _get_t_der(tz, derivative_z, dzi)
Expand Down Expand Up @@ -1129,13 +1129,13 @@ def loclip(fq, lo):
# lo is either False (no extrapolation) or a fixed value to fill in
if isbool(lo):
lo = jnp.nan
return jnp.where(xq < x[0], lo, fq)
return jnp.where(xq < x[0], lo, fq.T).T

def hiclip(fq, hi):
# hi is either False (no extrapolation) or a fixed value to fill in
if isbool(hi):
hi = jnp.nan
return jnp.where(xq > x[-1], hi, fq)
return jnp.where(xq > x[-1], hi, fq.T).T

def noclip(fq, *_):
return fq
Expand Down
78 changes: 78 additions & 0 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,38 @@ def test_interp1d(self, x):
fq = interp(x, xp, fp, method="monotonic-0")
np.testing.assert_allclose(fq, f(x), rtol=1e-4, atol=1e-2)

@pytest.mark.unit
def test_interp1d_vector_valued(self):
"""Test for interpolating vector valued function."""
xp = np.linspace(0, 2 * np.pi, 100)
x = np.linspace(0, 2 * np.pi, 300)[10:-10]
f = lambda x: np.array([np.sin(x), np.cos(x)])
fp = f(xp).T

fq = interp1d(x, xp, fp, method="nearest")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-2, atol=1e-1)

fq = interp1d(x, xp, fp, method="linear")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3)

fq = interp1d(x, xp, fp, method="cubic")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="cubic2")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="cardinal")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="catmull-rom")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="monotonic")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3)

fq = interp1d(x, xp, fp, method="monotonic-0")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-2)

@pytest.mark.unit
def test_interp1d_extrap_periodic(self):
"""Test extrapolation and periodic BC of 1d interpolation."""
Expand Down Expand Up @@ -156,6 +188,27 @@ def test_interp2d(self, x, y):
)
np.testing.assert_allclose(fq, f(x, y), rtol=rtol, atol=atol)

@pytest.mark.unit
def test_interp2d_vector_valued(self):
"""Test for interpolating vector valued function."""
xp = np.linspace(0, 3 * np.pi, 99)
yp = np.linspace(0, 2 * np.pi, 40)
x = np.linspace(0, 3 * np.pi, 200)
y = np.linspace(0, 2 * np.pi, 200)
xxp, yyp = np.meshgrid(xp, yp, indexing="ij")

f = lambda x, y: np.array([np.sin(x) * np.cos(y), np.sin(x) + np.cos(y)])
fp = f(xxp.T, yyp.T).T

fq = interp2d(x, y, xp, yp, fp, method="nearest")
np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-2, atol=1.2e-1)

fq = interp2d(x, y, xp, yp, fp, method="linear")
np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-3, atol=1e-2)

fq = interp2d(x, y, xp, yp, fp, method="cubic")
np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-5, atol=2e-3)


class TestInterp3D:
"""Tests for interp3d function."""
Expand Down Expand Up @@ -213,6 +266,31 @@ def test_interp3d(self, x, y, z):
fq = interp(x, y, z, xp, yp, zp, fp, method="cardinal")
np.testing.assert_allclose(fq, f(x, y, z), rtol=rtol, atol=atol)

@pytest.mark.unit
def test_interp3d_vector_valued(self):
"""Test for interpolating vector valued function."""
x = np.linspace(0, np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
z = np.linspace(0, 3, 1000)
xp = np.linspace(0, np.pi, 20)
yp = np.linspace(0, 2 * np.pi, 30)
zp = np.linspace(0, 3, 25)
xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij")

f = lambda x, y, z: np.array(
[np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)]
)
fp = f(xxp.T, yyp.T, zzp.T).T

fq = interp3d(x, y, z, xp, yp, zp, fp, method="nearest")
np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-2, atol=1)

fq = interp3d(x, y, z, xp, yp, zp, fp, method="linear")
np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-3, atol=1e-1)

fq = interp3d(x, y, z, xp, yp, zp, fp, method="cubic")
np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-5, atol=5e-3)


@pytest.mark.unit
def test_fft_interp1d():
Expand Down