From 3e2f3c1743cc68e661fbea6c82c13a379f46abda Mon Sep 17 00:00:00 2001 From: Pablo Winant Date: Mon, 18 Mar 2024 14:54:59 +0100 Subject: [PATCH 1/4] TEST: nonempty result As proposed by @kp992, check that multilinear call returns non-None result --- interpolation/multilinear/tests/test_multilinear.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/interpolation/multilinear/tests/test_multilinear.py b/interpolation/multilinear/tests/test_multilinear.py index 1139b72..bd0a905 100644 --- a/interpolation/multilinear/tests/test_multilinear.py +++ b/interpolation/multilinear/tests/test_multilinear.py @@ -115,7 +115,10 @@ def test_mlinterp(): pp = np.random.random((2000, 2)) res0 = mlinterp((x1, x2), y, pp) + assert res0 is not None + res0 = mlinterp((x1, x2), y, (0.1, 0.2)) + assert res0 is not None def test_multilinear(): @@ -125,6 +128,8 @@ def test_multilinear(): tt = [typeof(e) for e in t] rr = interp(*t) + assert rr is not None + try: print(f"{tt}: {rr.shape}") except: From fb0a074d098e6f040642896a4686adf3a7d58b5d Mon Sep 17 00:00:00 2001 From: Daisuke Oyama Date: Tue, 19 Mar 2024 14:54:49 +0900 Subject: [PATCH 2/4] Apply `@njit` to `interp` and `mlinterp` --- interpolation/multilinear/mlinterp.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/interpolation/multilinear/mlinterp.py b/interpolation/multilinear/mlinterp.py index c5dae9e..e04265b 100644 --- a/interpolation/multilinear/mlinterp.py +++ b/interpolation/multilinear/mlinterp.py @@ -41,11 +41,11 @@ # logic of multilinear interpolation -def mlinterp(grid, c, u): +def _mlinterp(grid, c, u): pass -@overload(mlinterp) +@overload(_mlinterp) def ol_mlinterp(grid, c, u): if isinstance(u, UniTuple): @@ -76,6 +76,11 @@ def mlininterp(grid: Tuple, c: Array, u: Array) -> float: return mlininterp +@njit +def mlinterp(grid, c, u): + return _mlinterp(grid, c, u) + + ### The rest of this file constrcts function `interp` from collections import namedtuple @@ -217,11 +222,11 @@ def {funname}(*args): return source -def interp(*args): +def _interp(*args): pass -@overload(interp) +@overload(_interp) def ol_interp(*args): aa = args[0].types @@ -235,3 +240,8 @@ def ol_interp(*args): code = compile(tree, "", "exec") eval(code, globals()) return __mlinterp + + +@njit +def interp(*args): + return _interp(*args) From dd22fe59bc51474a19fc439c731216a72aa9a79b Mon Sep 17 00:00:00 2001 From: Daisuke Oyama Date: Wed, 20 Mar 2024 14:38:42 +0900 Subject: [PATCH 3/4] Remove type annotations --- interpolation/multilinear/mlinterp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/interpolation/multilinear/mlinterp.py b/interpolation/multilinear/mlinterp.py index e04265b..5e23e9b 100644 --- a/interpolation/multilinear/mlinterp.py +++ b/interpolation/multilinear/mlinterp.py @@ -49,7 +49,7 @@ def _mlinterp(grid, c, u): def ol_mlinterp(grid, c, u): if isinstance(u, UniTuple): - def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float: + def mlininterp(grid, c, u): # get indices and barycentric coordinates tmp = fmap(get_index, grid, u) indices, barycenters = funzip(tmp) @@ -59,7 +59,7 @@ def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float: elif isinstance(u, Array) and u.ndim == 2: - def mlininterp(grid: Tuple, c: Array, u: Array) -> float: + def mlininterp(grid, c, u): N = u.shape[0] res = np.zeros(N) for n in range(N): From 7003f0fd625dab76364091b788ef9b095ad5a0cd Mon Sep 17 00:00:00 2001 From: Daisuke Oyama Date: Thu, 21 Mar 2024 10:42:20 +0900 Subject: [PATCH 4/4] Pass `args` directly to `detect_types` in `ol_interp` --- interpolation/multilinear/mlinterp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/interpolation/multilinear/mlinterp.py b/interpolation/multilinear/mlinterp.py index 5e23e9b..29ef09f 100644 --- a/interpolation/multilinear/mlinterp.py +++ b/interpolation/multilinear/mlinterp.py @@ -228,9 +228,7 @@ def _interp(*args): @overload(_interp) def ol_interp(*args): - aa = args[0].types - - it = detect_types(aa) + it = detect_types(args) if it.d == 1 and it.eval == "point": it = itt(it.d, it.values, "cartesian") source = make_mlinterp(it, "__mlinterp")