-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX: support_enumeration: Use
_numba_linalg_solve
(#311)
* support_enumeration: Remove fallback for Numba < 0.28 * support_enumeration: Add a test "LinAlgError: Matrix is singular to machine precision.” raised * FIX: support_enumeration: Use `_numba_linalg_solve` Remove `is_singular` by svd * util: Add `_numba_linalg_solve` For use in a jitted function in nopython mode * Call directly Numba internal `numba_xgesv` * Return nonzero int if input matrix is singular, allowing alternative to try-except np.linalg.LinAlgError * support_enumeration: Remove `any()` Allow `cache=True`, close #285
- Loading branch information
Showing
6 changed files
with
197 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,5 +7,6 @@ Utilities | |
util/array | ||
util/common_messages | ||
util/notebooks | ||
util/numba | ||
util/random | ||
util/timing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
numba | ||
===== | ||
|
||
.. automodule:: quantecon.util.numba | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" | ||
Utilities to support Numba jitted functions | ||
""" | ||
import numpy as np | ||
from numba import generated_jit, types | ||
from numba.targets.linalg import _LAPACK | ||
|
||
|
||
# BLAS kinds as letters | ||
_blas_kinds = { | ||
types.float32: 's', | ||
types.float64: 'd', | ||
types.complex64: 'c', | ||
types.complex128: 'z', | ||
} | ||
|
||
|
||
@generated_jit(nopython=True, cache=True) | ||
def _numba_linalg_solve(a, b): | ||
""" | ||
Solve the linear equation ax = b directly calling a Numba internal | ||
function. The data in `a` and `b` are interpreted in Fortran order, | ||
and dtype of `a` and `b` must be the same, one of {float32, float64, | ||
complex64, complex128}. `a` and `b` are modified in place, and the | ||
solution is stored in `b`. *No error check is made for the inputs.* | ||
Parameters | ||
---------- | ||
a : ndarray(ndim=2) | ||
2-dimensional ndarray of shape (n, n). | ||
b : ndarray(ndim=1 or 2) | ||
1-dimensional ndarray of shape (n,) or 2-dimensional ndarray of | ||
shape (n, nrhs). | ||
Returns | ||
------- | ||
r : scalar(int) | ||
r = 0 if successful. | ||
Notes | ||
----- | ||
From github.com/numba/numba/blob/master/numba/targets/linalg.py | ||
""" | ||
numba_xgesv = _LAPACK().numba_xgesv(a.dtype) | ||
kind = ord(_blas_kinds[a.dtype]) | ||
|
||
def _numba_linalg_solve_impl(a, b): # pragma: no cover | ||
n = a.shape[-1] | ||
if b.ndim == 1: | ||
nrhs = 1 | ||
else: # b.ndim == 2 | ||
nrhs = b.shape[-1] | ||
F_INT_nptype = np.int32 | ||
ipiv = np.empty(n, dtype=F_INT_nptype) | ||
|
||
r = numba_xgesv( | ||
kind, # kind | ||
n, # n | ||
nrhs, # nhrs | ||
a.ctypes, # a | ||
n, # lda | ||
ipiv.ctypes, # ipiv | ||
b.ctypes, # b | ||
n # ldb | ||
) | ||
return r | ||
|
||
return _numba_linalg_solve_impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
Tests for Numba support utilities | ||
""" | ||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
from numba import jit | ||
from nose.tools import eq_, ok_ | ||
from quantecon.util.numba import _numba_linalg_solve | ||
|
||
|
||
@jit(nopython=True) | ||
def numba_linalg_solve_orig(a, b): | ||
return np.linalg.solve(a, b) | ||
|
||
|
||
class TestNumbaLinalgSolve: | ||
def setUp(self): | ||
self.dtypes = [np.float32, np.float64] | ||
self.a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]]) | ||
self.b_1dim = np.array([2, 4, -1]) | ||
self.b_2dim = np.array([[2, 3], [4, 1], [-1, 0]]) | ||
self.a_singular = np.array([[0, 1, 2], [3, 4, 5], [3, 5, 7]]) | ||
|
||
def test_b_1dim(self): | ||
for dtype in self.dtypes: | ||
a = np.asfortranarray(self.a, dtype=dtype) | ||
b = np.asfortranarray(self.b_1dim, dtype=dtype) | ||
sol_orig = numba_linalg_solve_orig(a, b) | ||
r = _numba_linalg_solve(a, b) | ||
eq_(r, 0) | ||
assert_array_equal(b, sol_orig) | ||
|
||
def test_b_2dim(self): | ||
for dtype in self.dtypes: | ||
a = np.asfortranarray(self.a, dtype=dtype) | ||
b = np.asfortranarray(self.b_2dim, dtype=dtype) | ||
sol_orig = numba_linalg_solve_orig(a, b) | ||
r = _numba_linalg_solve(a, b) | ||
eq_(r, 0) | ||
assert_array_equal(b, sol_orig) | ||
|
||
def test_singular_a(self): | ||
for b in [self.b_1dim, self.b_2dim]: | ||
for dtype in self.dtypes: | ||
a = np.asfortranarray(self.a_singular, dtype=dtype) | ||
b = np.asfortranarray(b, dtype=dtype) | ||
r = _numba_linalg_solve(a, b) | ||
ok_(r != 0) | ||
|
||
|
||
if __name__ == '__main__': | ||
import sys | ||
import nose | ||
|
||
argv = sys.argv[:] | ||
argv.append('--verbose') | ||
argv.append('--nocapture') | ||
nose.main(argv=argv, defaultTest=__file__) |