Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace generated_jit with overload #112

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11']
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12']

name: Test Interpolation.py (Python ${{ matrix.python-version }})
steps:
Expand Down
1 change: 0 additions & 1 deletion examples/example_mlinterp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

from numba import generated_jit
import ast

C = ((0.1, 0.2), (0.1, 0.2))
Expand Down
44 changes: 36 additions & 8 deletions interpolation/multilinear/fungen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import numpy as np
from numba import float64, int64
from numba import generated_jit, njit
from numba import njit
import ast

from numba.extending import overload
Expand All @@ -25,8 +25,12 @@ def clamp(x, a, b):


# returns the index of a 1d point along a 1d dimension
@generated_jit(nopython=True)
def get_index(gc, x):
pass


@overload(get_index)
def ol_get_index(gc, x):
if gc == t_coord:
# regular coordinate
def fun(gc, x):
Expand All @@ -53,8 +57,12 @@ def fun(gc, x):


# returns number of dimension of a dimension
@generated_jit(nopython=True)
def get_size(gc):
pass


@overload(get_size)
def ol_get_size(gc):
if gc == t_coord:
# regular coordinate
def fun(gc):
Expand Down Expand Up @@ -145,8 +153,12 @@ def _map(*args):
# funzip(((1,2), (2,3), (4,3))) -> ((1,2,4),(2,3,3))


@generated_jit(nopython=True)
def funzip(t):
pass


@overload(funzip)
def ol_funzip(t):
k = t.count
assert len(set([e.count for e in t.types])) == 1
l = t.types[0].count
Expand All @@ -169,8 +181,12 @@ def print_tuple(t):
#####


@generated_jit(nopython=True)
def get_coeffs(X, I):
pass


@overload(get_coeffs)
def ol_get_coeffs(X, I):
if X.ndim > len(I):
print("not implemented yet")
else:
Expand Down Expand Up @@ -218,8 +234,12 @@ def gen_tensor_reduction(X, symbs, inds=[]):
return str.join(" + ", exprs)


@generated_jit(nopython=True)
def tensor_reduction(C, l):
pass


@overload(tensor_reduction)
def ol_tensor_reduction(C, l):
d = len(l.types)
ex = gen_tensor_reduction("C", ["l[{}]".format(i) for i in range(d)])
dd = dict()
Expand All @@ -228,8 +248,12 @@ def tensor_reduction(C, l):
return dd["tensor_reduction"]


@generated_jit(nopython=True)
def extract_row(a, n, tup):
pass


@overload(extract_row)
def ol_extract_row(a, n, tup):
d = len(tup.types)
dd = {}
s = "def extract_row(a, n, tup): return ({},)".format(
Expand All @@ -240,8 +264,12 @@ def extract_row(a, n, tup):


# find closest point inside the grid domain
@generated_jit
def project(grid, point):
pass


@overload(project)
def ol_project(grid, point):
s = "def __project(grid, point):\n"
d = len(grid.types)
for i in range(d):
Expand Down
15 changes: 10 additions & 5 deletions interpolation/multilinear/mlinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,24 @@
)

from numba import njit
from numba.extending import overload
from typing import Tuple

from ..compat import UniTuple, Tuple, Float, Integer, Array

Scalar = (Float, Integer)

import numpy as np
from numba import generated_jit

# logic of multilinear interpolation


@generated_jit
def mlinterp(grid, c, u):
pass


@overload(mlinterp)
def ol_mlinterp(grid, c, u):
if isinstance(u, UniTuple):

def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
Expand Down Expand Up @@ -213,11 +217,12 @@ def {funname}(*args):
return source


from numba import generated_jit
def interp(*args):
pass


@generated_jit(nopython=True)
def interp(*args):
@overload(interp)
def ol_interp(*args):
aa = args[0].types

it = detect_types(aa)
Expand Down
18 changes: 12 additions & 6 deletions interpolation/multilinear/tests/test_multilinear.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from numpy import linspace, array
from numpy.random import random
from numba import typeof
from numba import njit

import numpy as np
from ..fungen import get_index


@njit
def get_index_njit(gc, x):
return get_index(gc, x)


def test_barycentric_indexes():
# irregular grid
gg = np.array([0.0, 1.0])
assert get_index(gg, -0.1) == (0, -0.1)
assert get_index(gg, 0.5) == (0, 0.5)
assert get_index(gg, 1.1) == (0, 1.1)
assert get_index_njit(gg, -0.1) == (0, -0.1)
assert get_index_njit(gg, 0.5) == (0, 0.5)
assert get_index_njit(gg, 1.1) == (0, 1.1)

# regular grid
gg = (0.0, 1.0, 2)
assert get_index(gg, -0.1) == (0, -0.1)
assert get_index(gg, 0.5) == (0, 0.5)
assert get_index(gg, 1.1) == (0, 1.1)
assert get_index_njit(gg, -0.1) == (0, -0.1)
assert get_index_njit(gg, 0.5) == (0, 0.5)
assert get_index_njit(gg, 1.1) == (0, 1.1)


# 2d-vecev-scalar
Expand Down
16 changes: 13 additions & 3 deletions interpolation/splines/eval_cubic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy

from numba import njit
from numba.extending import overload
from .eval_splines import eval_cubic

## the functions in this file provide backward compatibility calls
Expand All @@ -11,19 +13,27 @@
# Compatibility calls #
#######################

from numba import generated_jit
from .codegen import source_to_function


@generated_jit
def get_grid(a, b, n, C):
def _get_grid(a, b, n, C):
pass


@overload(_get_grid)
def ol_get_grid(a, b, n, C):
d = C.ndim
s = "({},)".format(str.join(", ", [f"(a[{k}],b[{k}],n[{k}])" for k in range(d)]))
txt = "def get_grid(a,b,n,C): return {}".format(s)
f = source_to_function(txt)
return f


@njit
def get_grid(a, b, n, C):
return _get_grid(a, b, n, C)


def eval_cubic_spline(a, b, orders, coefs, point):
"""Evaluates a cubic spline at one point

Expand Down
9 changes: 5 additions & 4 deletions interpolation/splines/eval_splines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from numba import jit, generated_jit
from numpy import zeros
from numpy import floor

Expand All @@ -19,7 +18,6 @@
from interpolation.splines.codegen import get_code_spline, source_to_function
from numba.types import UniTuple, float64, Array
from interpolation.splines.codegen import source_to_function
from numba import generated_jit


from ..compat import Tuple, UniTuple
Expand Down Expand Up @@ -50,9 +48,12 @@
### eval spline (main function)


# @generated_jit(inline='always', nopython=True) # doens't work
@generated_jit(nopython=True)
def allocate_output(G, C, P, O):
pass


@overload(allocate_output)
def ol_allocate_output(G, C, P, O):
if C.ndim == len(G) + 1:
# vector valued
if P.ndim == 2:
Expand Down
7 changes: 5 additions & 2 deletions interpolation/splines/hermite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def HermiteInterpolationVect(xvect, x: Vector, y: Vector, yp: Vector):

from numba import njit, types
from numba.extending import overload, register_jitable
from numba import generated_jit


def _hermite(x0, x, y, yp, out=None):
Expand All @@ -102,8 +101,12 @@ def _hermite(x0, x, y, yp, out=None):
from numba.core.types.misc import NoneType as none


@generated_jit
def hermite(x0, x, y, yp, out=None):
pass


@overload(hermite)
def ol_hermite(x0, x, y, yp, out=None):
try:
n = x0.ndim
if n == 1:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ maintainers = [
license = "BSD-2-Clause"

[tool.poetry.dependencies]
python = ">=3.9, <=3.12"
numba = "^0.57"
python = ">=3.9"
numba = ">=0.57"
scipy = "^1.10"

[tool.poetry.dev-dependencies]
Expand Down