Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply @njit to interp and mlinterp #114

Merged
merged 4 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions interpolation/multilinear/mlinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@
# logic of multilinear interpolation


def mlinterp(grid, c, u):
def _mlinterp(grid, c, u):
pass


@overload(mlinterp)
@overload(_mlinterp)
def ol_mlinterp(grid, c, u):
if isinstance(u, UniTuple):

def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
def mlininterp(grid, c, u):
# get indices and barycentric coordinates
tmp = fmap(get_index, grid, u)
indices, barycenters = funzip(tmp)
Expand All @@ -59,7 +59,7 @@ def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:

elif isinstance(u, Array) and u.ndim == 2:

def mlininterp(grid: Tuple, c: Array, u: Array) -> float:
def mlininterp(grid, c, u):
N = u.shape[0]
res = np.zeros(N)
for n in range(N):
Expand All @@ -76,6 +76,11 @@ def mlininterp(grid: Tuple, c: Array, u: Array) -> float:
return mlininterp


@njit
def mlinterp(grid, c, u):
return _mlinterp(grid, c, u)


### The rest of this file constrcts function `interp`

from collections import namedtuple
Expand Down Expand Up @@ -217,15 +222,13 @@ def {funname}(*args):
return source


def interp(*args):
def _interp(*args):
pass


@overload(interp)
@overload(_interp)
def ol_interp(*args):
aa = args[0].types

it = detect_types(aa)
it = detect_types(args)
if it.d == 1 and it.eval == "point":
it = itt(it.d, it.values, "cartesian")
source = make_mlinterp(it, "__mlinterp")
Expand All @@ -235,3 +238,8 @@ def ol_interp(*args):
code = compile(tree, "<string>", "exec")
eval(code, globals())
return __mlinterp


@njit
def interp(*args):
return _interp(*args)
5 changes: 5 additions & 0 deletions interpolation/multilinear/tests/test_multilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def test_mlinterp():
pp = np.random.random((2000, 2))

res0 = mlinterp((x1, x2), y, pp)
assert res0 is not None

res0 = mlinterp((x1, x2), y, (0.1, 0.2))
assert res0 is not None


def test_multilinear():
Expand All @@ -125,6 +128,8 @@ def test_multilinear():
tt = [typeof(e) for e in t]
rr = interp(*t)

assert rr is not None

try:
print(f"{tt}: {rr.shape}")
except:
Expand Down