Skip to content

Commit

Permalink
Merge pull request #377 from QuantEcon/comb_jit
Browse files Browse the repository at this point in the history
Add Numba jit version of scipy.special.comb
  • Loading branch information
mmcky authored Jan 4, 2018
2 parents 828aaf2 + 0c35d78 commit 475650b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 8 deletions.
24 changes: 20 additions & 4 deletions quantecon/gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import scipy.special
from numba import jit, njit
from .util.numba import comb_jit


def cartesian(nodes, order='C'):
Expand Down Expand Up @@ -124,7 +125,10 @@ def _repeat_1d(x, K, out):
out[ind] = val


@jit
_msg_max_size_exceeded = 'Maximum allowed size exceeded'


@jit(nopython=True, cache=True)
def simplex_grid(m, n):
r"""
Construct an array consisting of the integer points in the
Expand Down Expand Up @@ -196,10 +200,12 @@ def simplex_grid(m, n):
Academic Press, 1978.
"""
L = num_compositions(m, n)
out = np.empty((L, m), dtype=int)
L = num_compositions_jit(m, n)
if L == 0: # Overflow occured
raise ValueError(_msg_max_size_exceeded)
out = np.empty((L, m), dtype=np.int_)

x = np.zeros(m, dtype=int)
x = np.zeros(m, dtype=np.int_)
x[m-1] = n

for j in range(m):
Expand Down Expand Up @@ -282,3 +288,13 @@ def num_compositions(m, n):
"""
# docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html
return scipy.special.comb(n+m-1, m-1, exact=True)


@jit(nopython=True, cache=True)
def num_compositions_jit(m, n):
"""
Numba jit version of `num_compositions`. Return `0` if the outcome
exceeds the maximum value of `np.intp`.
"""
return comb_jit(n+m-1, m-1)
15 changes: 13 additions & 2 deletions quantecon/tests/test_gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
"""
import numpy as np
from numpy.testing import assert_array_equal
from nose.tools import eq_
from nose.tools import eq_, raises
from nose.plugins.attrib import attr

from quantecon.gridtools import (
cartesian, mlinspace, _repeat_1d, simplex_grid, simplex_index,
num_compositions
num_compositions, num_compositions_jit
)


Expand Down Expand Up @@ -212,6 +212,17 @@ def test_num_compositions(self):
num = num_compositions(3, 4)
eq_(num, len(self.simplex_grid_3_4))

def test_num_compositions_jit(self):
num = num_compositions_jit(3, 4)
eq_(num, len(self.simplex_grid_3_4))

eq_(num_compositions_jit(100, 50), 0) # Exceed max value of np.intp


@raises(ValueError)
def test_simplex_grid_raises_value_error_overflow():
simplex_grid(100, 50) # Exceed max value of np.intp


if __name__ == '__main__':
import sys
Expand Down
48 changes: 47 additions & 1 deletion quantecon/util/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import numpy as np
from numba import generated_jit, types
from numba import jit, generated_jit, types
from numba.targets.linalg import _LAPACK


Expand Down Expand Up @@ -69,3 +69,49 @@ def _numba_linalg_solve_impl(a, b): # pragma: no cover
return r

return _numba_linalg_solve_impl


@jit(types.intp(types.intp, types.intp), nopython=True, cache=True)
def comb_jit(N, k):
"""
Numba jitted function that computes N choose k. Return `0` if the
outcome exceeds the maximum value of `np.intp` or if N < 0, k < 0,
or k > N.
Parameters
----------
N : scalar(int)
k : scalar(int)
Returns
-------
val : scalar(int)
"""
# From scipy.special._comb_int_long
# github.com/scipy/scipy/blob/v1.0.0/scipy/special/_comb.pyx
INTP_MAX = np.iinfo(np.intp).max
if N < 0 or k < 0 or k > N:
return 0
if k == 0:
return 1
if k == 1:
return N
if N == INTP_MAX:
return 0

M = N + 1
nterms = min(k, N - k)

val = 1

for j in range(1, nterms + 1):
# Overflow check
if val > INTP_MAX // (M - j):
return 0

val *= M - j
val //= j

return val
26 changes: 25 additions & 1 deletion quantecon/util/tests/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.testing import assert_array_equal
from numba import jit
from nose.tools import eq_, ok_
from quantecon.util.numba import _numba_linalg_solve
from quantecon.util.numba import _numba_linalg_solve, comb_jit


@jit(nopython=True)
Expand Down Expand Up @@ -49,6 +49,30 @@ def test_singular_a(self):
ok_(r != 0)


class TestCombJit:
def setUp(self):
self.MAX_INTP = np.iinfo(np.intp).max

def test_comb(self):
N, k = 10, 3
N_choose_k = 120
eq_(comb_jit(N, k), N_choose_k)

def test_comb_zeros(self):
eq_(comb_jit(2, 3), 0)
eq_(comb_jit(-1, 3), 0)
eq_(comb_jit(2, -1), 0)

eq_(comb_jit(self.MAX_INTP, 2), 0)

N = np.int(self.MAX_INTP**0.5 * 2**0.5) + 1
eq_(comb_jit(N, 2), 0)

def test_max_intp(self):
eq_(comb_jit(self.MAX_INTP, 0), 1)
eq_(comb_jit(self.MAX_INTP, 1), self.MAX_INTP)


if __name__ == '__main__':
import sys
import nose
Expand Down

0 comments on commit 475650b

Please sign in to comment.