From f9f6f644510a323b1f7fac7b246d35e9e5c39f4b Mon Sep 17 00:00:00 2001 From: Daisuke Oyama Date: Sun, 17 Mar 2024 16:02:42 +0900 Subject: [PATCH] Replace `generated_jit` with `overload` --- .github/workflows/ci.yaml | 2 +- examples/example_mlinterp.py | 1 - interpolation/multilinear/fungen.py | 44 +++++++++++++++---- interpolation/multilinear/mlinterp.py | 15 ++++--- .../multilinear/tests/test_multilinear.py | 18 +++++--- interpolation/splines/eval_cubic.py | 16 +++++-- interpolation/splines/eval_splines.py | 9 ++-- interpolation/splines/hermite.py | 7 ++- pyproject.toml | 4 +- 9 files changed, 84 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6695aa0..073e1b7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.8', '3.9', '3.10', '3.11'] + python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12'] name: Test Interpolation.py (Python ${{ matrix.python-version }}) steps: diff --git a/examples/example_mlinterp.py b/examples/example_mlinterp.py index 7745017..25adcfa 100644 --- a/examples/example_mlinterp.py +++ b/examples/example_mlinterp.py @@ -1,6 +1,5 @@ import numpy as np -from numba import generated_jit import ast C = ((0.1, 0.2), (0.1, 0.2)) diff --git a/interpolation/multilinear/fungen.py b/interpolation/multilinear/fungen.py index d0f13e8..a7a9ec1 100644 --- a/interpolation/multilinear/fungen.py +++ b/interpolation/multilinear/fungen.py @@ -1,7 +1,7 @@ import numba import numpy as np from numba import float64, int64 -from numba import generated_jit, njit +from numba import njit import ast from numba.extending import overload @@ -25,8 +25,12 @@ def clamp(x, a, b): # returns the index of a 1d point along a 1d dimension -@generated_jit(nopython=True) def get_index(gc, x): + pass + + +@overload(get_index) +def ol_get_index(gc, x): if gc == t_coord: # regular coordinate def fun(gc, x): @@ -53,8 +57,12 @@ def fun(gc, x): # returns number of dimension of a dimension -@generated_jit(nopython=True) def get_size(gc): + pass + + +@overload(get_size) +def ol_get_size(gc): if gc == t_coord: # regular coordinate def fun(gc): @@ -145,8 +153,12 @@ def _map(*args): # funzip(((1,2), (2,3), (4,3))) -> ((1,2,4),(2,3,3)) -@generated_jit(nopython=True) def funzip(t): + pass + + +@overload(funzip) +def ol_funzip(t): k = t.count assert len(set([e.count for e in t.types])) == 1 l = t.types[0].count @@ -169,8 +181,12 @@ def print_tuple(t): ##### -@generated_jit(nopython=True) def get_coeffs(X, I): + pass + + +@overload(get_coeffs) +def ol_get_coeffs(X, I): if X.ndim > len(I): print("not implemented yet") else: @@ -218,8 +234,12 @@ def gen_tensor_reduction(X, symbs, inds=[]): return str.join(" + ", exprs) -@generated_jit(nopython=True) def tensor_reduction(C, l): + pass + + +@overload(tensor_reduction) +def ol_tensor_reduction(C, l): d = len(l.types) ex = gen_tensor_reduction("C", ["l[{}]".format(i) for i in range(d)]) dd = dict() @@ -228,8 +248,12 @@ def tensor_reduction(C, l): return dd["tensor_reduction"] -@generated_jit(nopython=True) def extract_row(a, n, tup): + pass + + +@overload(extract_row) +def ol_extract_row(a, n, tup): d = len(tup.types) dd = {} s = "def extract_row(a, n, tup): return ({},)".format( @@ -240,8 +264,12 @@ def extract_row(a, n, tup): # find closest point inside the grid domain -@generated_jit def project(grid, point): + pass + + +@overload(project) +def ol_project(grid, point): s = "def __project(grid, point):\n" d = len(grid.types) for i in range(d): diff --git a/interpolation/multilinear/mlinterp.py b/interpolation/multilinear/mlinterp.py index 50d2e0c..c5dae9e 100644 --- a/interpolation/multilinear/mlinterp.py +++ b/interpolation/multilinear/mlinterp.py @@ -29,6 +29,7 @@ ) from numba import njit +from numba.extending import overload from typing import Tuple from ..compat import UniTuple, Tuple, Float, Integer, Array @@ -36,13 +37,16 @@ Scalar = (Float, Integer) import numpy as np -from numba import generated_jit # logic of multilinear interpolation -@generated_jit def mlinterp(grid, c, u): + pass + + +@overload(mlinterp) +def ol_mlinterp(grid, c, u): if isinstance(u, UniTuple): def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float: @@ -213,11 +217,12 @@ def {funname}(*args): return source -from numba import generated_jit +def interp(*args): + pass -@generated_jit(nopython=True) -def interp(*args): +@overload(interp) +def ol_interp(*args): aa = args[0].types it = detect_types(aa) diff --git a/interpolation/multilinear/tests/test_multilinear.py b/interpolation/multilinear/tests/test_multilinear.py index b852dd9..1139b72 100644 --- a/interpolation/multilinear/tests/test_multilinear.py +++ b/interpolation/multilinear/tests/test_multilinear.py @@ -1,23 +1,29 @@ from numpy import linspace, array from numpy.random import random from numba import typeof +from numba import njit import numpy as np from ..fungen import get_index +@njit +def get_index_njit(gc, x): + return get_index(gc, x) + + def test_barycentric_indexes(): # irregular grid gg = np.array([0.0, 1.0]) - assert get_index(gg, -0.1) == (0, -0.1) - assert get_index(gg, 0.5) == (0, 0.5) - assert get_index(gg, 1.1) == (0, 1.1) + assert get_index_njit(gg, -0.1) == (0, -0.1) + assert get_index_njit(gg, 0.5) == (0, 0.5) + assert get_index_njit(gg, 1.1) == (0, 1.1) # regular grid gg = (0.0, 1.0, 2) - assert get_index(gg, -0.1) == (0, -0.1) - assert get_index(gg, 0.5) == (0, 0.5) - assert get_index(gg, 1.1) == (0, 1.1) + assert get_index_njit(gg, -0.1) == (0, -0.1) + assert get_index_njit(gg, 0.5) == (0, 0.5) + assert get_index_njit(gg, 1.1) == (0, 1.1) # 2d-vecev-scalar diff --git a/interpolation/splines/eval_cubic.py b/interpolation/splines/eval_cubic.py index c532606..6e88ce8 100644 --- a/interpolation/splines/eval_cubic.py +++ b/interpolation/splines/eval_cubic.py @@ -1,5 +1,7 @@ import numpy +from numba import njit +from numba.extending import overload from .eval_splines import eval_cubic ## the functions in this file provide backward compatibility calls @@ -11,12 +13,15 @@ # Compatibility calls # ####################### -from numba import generated_jit from .codegen import source_to_function -@generated_jit -def get_grid(a, b, n, C): +def _get_grid(a, b, n, C): + pass + + +@overload(_get_grid) +def ol_get_grid(a, b, n, C): d = C.ndim s = "({},)".format(str.join(", ", [f"(a[{k}],b[{k}],n[{k}])" for k in range(d)])) txt = "def get_grid(a,b,n,C): return {}".format(s) @@ -24,6 +29,11 @@ def get_grid(a, b, n, C): return f +@njit +def get_grid(a, b, n, C): + return _get_grid(a, b, n, C) + + def eval_cubic_spline(a, b, orders, coefs, point): """Evaluates a cubic spline at one point diff --git a/interpolation/splines/eval_splines.py b/interpolation/splines/eval_splines.py index 39368d8..2e2adc7 100644 --- a/interpolation/splines/eval_splines.py +++ b/interpolation/splines/eval_splines.py @@ -1,5 +1,4 @@ import numpy as np -from numba import jit, generated_jit from numpy import zeros from numpy import floor @@ -19,7 +18,6 @@ from interpolation.splines.codegen import get_code_spline, source_to_function from numba.types import UniTuple, float64, Array from interpolation.splines.codegen import source_to_function -from numba import generated_jit from ..compat import Tuple, UniTuple @@ -50,9 +48,12 @@ ### eval spline (main function) -# @generated_jit(inline='always', nopython=True) # doens't work -@generated_jit(nopython=True) def allocate_output(G, C, P, O): + pass + + +@overload(allocate_output) +def ol_allocate_output(G, C, P, O): if C.ndim == len(G) + 1: # vector valued if P.ndim == 2: diff --git a/interpolation/splines/hermite.py b/interpolation/splines/hermite.py index fa0c7a3..aadfab5 100644 --- a/interpolation/splines/hermite.py +++ b/interpolation/splines/hermite.py @@ -84,7 +84,6 @@ def HermiteInterpolationVect(xvect, x: Vector, y: Vector, yp: Vector): from numba import njit, types from numba.extending import overload, register_jitable -from numba import generated_jit def _hermite(x0, x, y, yp, out=None): @@ -102,8 +101,12 @@ def _hermite(x0, x, y, yp, out=None): from numba.core.types.misc import NoneType as none -@generated_jit def hermite(x0, x, y, yp, out=None): + pass + + +@overload(hermite) +def ol_hermite(x0, x, y, yp, out=None): try: n = x0.ndim if n == 1: diff --git a/pyproject.toml b/pyproject.toml index b276529..5444d43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,8 @@ maintainers = [ license = "BSD-2-Clause" [tool.poetry.dependencies] -python = ">=3.9, <=3.12" -numba = "^0.57" +python = ">=3.9" +numba = ">=0.57" scipy = "^1.10" [tool.poetry.dev-dependencies]