Skip to content

Commit

Permalink
Use lineax for tridiagonal solve
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Mar 3, 2024
1 parent 1047afe commit 789b753
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
30 changes: 12 additions & 18 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import lineax as lx
import numpy as np
from jax import jit

Expand Down Expand Up @@ -1228,22 +1229,14 @@ def approx_df(
dxi = jnp.where(dx == 0, 0, 1 / dx)
df = dxi * df

A = jnp.diag(
jnp.concatenate(
(
np.array([1.0]),
2 * (dx.flatten()[:-1] + dx.flatten()[1:]),
np.array([1.0]),
)
)
)
upper_diag1 = jnp.diag(
jnp.concatenate((np.array([1.0]), dx.flatten()[:-1])), k=1
)
lower_diag1 = jnp.diag(
jnp.concatenate((dx.flatten()[1:], np.array([1.0]))), k=-1
)
A += upper_diag1 + lower_diag1
one = jnp.array([1.0])
dxflat = dx.flatten()
diag = jnp.concatenate([one, 2 * (dxflat[:-1] + dxflat[1:]), one])
upper_diag = jnp.concatenate([one, dxflat[:-1]])
lower_diag = jnp.concatenate([dxflat[1:], one])

A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)

b = jnp.concatenate(
[
2 * jnp.take(df, jnp.array([0]), axis, mode="wrap"),
Expand All @@ -1260,8 +1253,9 @@ def approx_df(
)
ba = jnp.moveaxis(b, axis, 0)
br = ba.reshape((b.shape[axis], -1))
fx = jnp.linalg.solve(A, br).reshape(ba.shape)
fx = jnp.moveaxis(fx, 0, axis)
solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value
fx = jnp.vectorize(solve, signature="(n)->(n)")(br.T).T
fx = jnp.moveaxis(fx.reshape(ba.shape), 0, axis)
return fx

elif method in ["cardinal", "catmull-rom"]:
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
-r ./requirements.txt

jax[cpu] >= 0.3.2, <= 0.5.0
scipy >= 1.5.0, < 2.0

# building the docs
sphinx > 3.0.0
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
equinox
jax >= 0.3.2, <= 0.5.0
lineax
numpy >= 1.20.0, < 2.0
scipy >= 1.5.0, < 2.0

0 comments on commit 789b753

Please sign in to comment.