diff --git a/interpolation/multilinear/mlinterp.py b/interpolation/multilinear/mlinterp.py index c5dae9e..29ef09f 100644 --- a/interpolation/multilinear/mlinterp.py +++ b/interpolation/multilinear/mlinterp.py @@ -41,15 +41,15 @@ # logic of multilinear interpolation -def mlinterp(grid, c, u): +def _mlinterp(grid, c, u): pass -@overload(mlinterp) +@overload(_mlinterp) def ol_mlinterp(grid, c, u): if isinstance(u, UniTuple): - def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float: + def mlininterp(grid, c, u): # get indices and barycentric coordinates tmp = fmap(get_index, grid, u) indices, barycenters = funzip(tmp) @@ -59,7 +59,7 @@ def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float: elif isinstance(u, Array) and u.ndim == 2: - def mlininterp(grid: Tuple, c: Array, u: Array) -> float: + def mlininterp(grid, c, u): N = u.shape[0] res = np.zeros(N) for n in range(N): @@ -76,6 +76,11 @@ def mlininterp(grid: Tuple, c: Array, u: Array) -> float: return mlininterp +@njit +def mlinterp(grid, c, u): + return _mlinterp(grid, c, u) + + ### The rest of this file constrcts function `interp` from collections import namedtuple @@ -217,15 +222,13 @@ def {funname}(*args): return source -def interp(*args): +def _interp(*args): pass -@overload(interp) +@overload(_interp) def ol_interp(*args): - aa = args[0].types - - it = detect_types(aa) + it = detect_types(args) if it.d == 1 and it.eval == "point": it = itt(it.d, it.values, "cartesian") source = make_mlinterp(it, "__mlinterp") @@ -235,3 +238,8 @@ def ol_interp(*args): code = compile(tree, "", "exec") eval(code, globals()) return __mlinterp + + +@njit +def interp(*args): + return _interp(*args) diff --git a/interpolation/multilinear/tests/test_multilinear.py b/interpolation/multilinear/tests/test_multilinear.py index 1139b72..bd0a905 100644 --- a/interpolation/multilinear/tests/test_multilinear.py +++ b/interpolation/multilinear/tests/test_multilinear.py @@ -115,7 +115,10 @@ def test_mlinterp(): pp = np.random.random((2000, 2)) res0 = mlinterp((x1, x2), y, pp) + assert res0 is not None + res0 = mlinterp((x1, x2), y, (0.1, 0.2)) + assert res0 is not None def test_multilinear(): @@ -125,6 +128,8 @@ def test_multilinear(): tt = [typeof(e) for e in t] rr = interp(*t) + assert rr is not None + try: print(f"{tt}: {rr.shape}") except: