-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding jitted scalar maximization routine, first build #416
Changes from 1 commit
152c067
4384579
c9933d3
22dce1d
e890f71
f722b25
70be953
44df5df
7526895
7f24600
d221eb1
09b5531
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" | ||
Initialization of the optimize subpackage | ||
""" | ||
|
||
from .scalar_maximization import maximize_scalar | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import numpy as np | ||
from numba import jit, njit | ||
|
||
@njit | ||
def maximize_scalar(func, a, b, xtol=1e-5, maxiter=500): | ||
""" | ||
Uses a jitted version of the maximization routine from SciPy's fminbound. | ||
The algorithm is identical except that it's been switched to maximization | ||
rather than minimization, and the tests for convergence have been stripped | ||
out to allow for jit compilation. | ||
|
||
Note that the input function `func` must be jitted or the call will fail. | ||
|
||
Parameters | ||
---------- | ||
maxiter : int, optional | ||
Maximum number of iterations to perform. | ||
xtol : float, optional | ||
Absolute error in solution `xopt` acceptable for convergence. | ||
func : jitted function | ||
a : scalar | ||
Lower bound for search | ||
b : scalar | ||
Upper bound for search | ||
|
||
Returns | ||
------- | ||
fval : float | ||
The maximum value attained | ||
xf : float | ||
The maxizer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you mean "maximizer"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @QBatista ! |
||
|
||
Example | ||
------- | ||
|
||
``` | ||
@njit | ||
def f(x): | ||
return -(x + 2.0)**2 + 1.0 | ||
|
||
fval, xf = maximize_scalar(f, -2, 2) | ||
``` | ||
|
||
""" | ||
maxfun = maxiter | ||
|
||
sqrt_eps = np.sqrt(2.2e-16) | ||
golden_mean = 0.5 * (3.0 - np.sqrt(5.0)) | ||
|
||
fulc = a + golden_mean * (b - a) | ||
nfc, xf = fulc, fulc | ||
rat = e = 0.0 | ||
x = xf | ||
fx = -func(x) | ||
num = 1 | ||
fmin_data = (1, xf, fx) | ||
|
||
ffulc = fnfc = fx | ||
xm = 0.5 * (a + b) | ||
tol1 = sqrt_eps * np.abs(xf) + xtol / 3.0 | ||
tol2 = 2.0 * tol1 | ||
|
||
while (np.abs(xf - xm) > (tol2 - 0.5 * (b - a))): | ||
golden = 1 | ||
# Check for parabolic fit | ||
if np.abs(e) > tol1: | ||
golden = 0 | ||
r = (xf - nfc) * (fx - ffulc) | ||
q = (xf - fulc) * (fx - fnfc) | ||
p = (xf - fulc) * q - (xf - nfc) * r | ||
q = 2.0 * (q - r) | ||
if q > 0.0: | ||
p = -p | ||
q = np.abs(q) | ||
r = e | ||
e = rat | ||
|
||
# Check for acceptability of parabola | ||
if ((np.abs(p) < np.abs(0.5*q*r)) and (p > q*(a - xf)) and | ||
(p < q * (b - xf))): | ||
rat = (p + 0.0) / q | ||
x = xf + rat | ||
|
||
if ((x - a) < tol2) or ((b - x) < tol2): | ||
si = np.sign(xm - xf) + ((xm - xf) == 0) | ||
rat = tol1 * si | ||
else: # do a golden section step | ||
golden = 1 | ||
|
||
if golden: # Do a golden-section step | ||
if xf >= xm: | ||
e = a - xf | ||
else: | ||
e = b - xf | ||
rat = golden_mean*e | ||
|
||
if rat == 0: | ||
si = np.sign(rat) + 1 | ||
else: | ||
si = np.sign(rat) | ||
|
||
x = xf + si * np.maximum(np.abs(rat), tol1) | ||
fu = -func(x) | ||
num += 1 | ||
fmin_data = (num, x, fu) | ||
|
||
if fu <= fx: | ||
if x >= xf: | ||
a = xf | ||
else: | ||
b = xf | ||
fulc, ffulc = nfc, fnfc | ||
nfc, fnfc = xf, fx | ||
xf, fx = x, fu | ||
else: | ||
if x < xf: | ||
a = x | ||
else: | ||
b = x | ||
if (fu <= fnfc) or (nfc == xf): | ||
fulc, ffulc = nfc, fnfc | ||
nfc, fnfc = x, fu | ||
elif (fu <= ffulc) or (fulc == xf) or (fulc == nfc): | ||
fulc, ffulc = x, fu | ||
|
||
xm = 0.5 * (a + b) | ||
tol1 = sqrt_eps * np.abs(xf) + xtol / 3.0 | ||
tol2 = 2.0 * tol1 | ||
|
||
if num >= maxfun: | ||
break | ||
|
||
fval = -fx | ||
|
||
return fval, xf | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a status flag ( |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
Tests for scalar maximization. | ||
|
||
""" | ||
import numpy as np | ||
from numpy.testing import assert_almost_equal | ||
from numba import njit | ||
|
||
from quantecon.optimize import maximize_scalar | ||
|
||
@njit | ||
def f(x): | ||
""" | ||
A function for testing on. | ||
""" | ||
return -(x + 2.0)**2 + 1.0 | ||
|
||
def test_maximize_scalar(): | ||
""" | ||
Uses the function f defined above to test the scalar maximization | ||
routine. | ||
""" | ||
true_fval = 1.0 | ||
true_xf = -2.0 | ||
fval, xf = maximize_scalar(f, -2, 2) | ||
assert_almost_equal(true_fval, fval, decimal=4) | ||
assert_almost_equal(true_xf, xf, decimal=4) | ||
|
||
|
||
if __name__ == '__main__': | ||
import sys | ||
import nose | ||
|
||
argv = sys.argv[:] | ||
argv.append('--verbose') | ||
argv.append('--nocapture') | ||
nose.main(argv=argv, defaultTest=__file__) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
""" | ||
This is a VERSION file and should NOT be manually altered | ||
""" | ||
version = '0.3.7' | ||
version = '0.3.8' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will be set in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be helpful to add
*args
(to pass tofunc
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@oyamad thanks I have added this, although unfortunately you need to be quite careful to pass in a tuple and not a scalar, as I can't figure out how to check the type inside a jitted function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I don't see what you mean. Can you elaborate?
And any difference between adding
args=()
and*args
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In scipy, if something in
args
is passed that is not a tuple (ie. a scalar), the function will convert it to a tuple. I don't seem to be able to getisinstance
to work inside a jitted function. If you only want to set one fixed argument, you need to passargs=(y,)
which is somewhat annoying.I guess we could use
*args
- I was just following scipy's style.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see thanks.
Following exactly scipy's style makes sense, while
*args
looks more Pythonic, allowing passing e.g.y=5
. I can't say which is better...