Skip to content

Commit

Permalink
Add notes to docstrings, fix romberg extrapolation
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Oct 29, 2023
1 parent 11d5f2f commit 15e54d2
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 50 deletions.
20 changes: 20 additions & 0 deletions quadax/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def quadgk(
elements of which are the moduli of the absolute error estimates on the
sub-intervals.
Notes
-----
Adaptive algorithms are inherently somewhat sequential, so perfect parallelism
is generally not achievable. The local quadrature rule vmaps integrand evaluation at
``order`` points, so using higher order methods will generally be more efficient on
GPU/TPU.
"""
y, info = adaptive_quadrature(
fun, a, b, args, full_output, epsabs, epsrel, max_ninter, fixed_quadgk, n=order
Expand Down Expand Up @@ -161,6 +168,13 @@ def quadcc(
elements of which are the moduli of the absolute error estimates on the
sub-intervals.
Notes
-----
Adaptive algorithms are inherently somewhat sequential, so perfect parallelism
is generally not achievable. The local quadrature rule vmaps integrand evaluation at
``order`` points, so using higher order methods will generally be more efficient on
GPU/TPU.
"""
y, info = adaptive_quadrature(
fun, a, b, args, full_output, epsabs, epsrel, max_ninter, fixed_quadcc, n=order
Expand Down Expand Up @@ -237,6 +251,12 @@ def quadts(
elements of which are the moduli of the absolute error estimates on the
sub-intervals.
Notes
-----
Adaptive algorithms are inherently somewhat sequential, so perfect parallelism
is generally not achievable. The local quadrature rule vmaps integrand evaluation at
``order`` points, so using higher order methods will generally be more efficient on
GPU/TPU.
"""
y, info = adaptive_quadrature(
Expand Down
37 changes: 23 additions & 14 deletions quadax/romberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def romberg(
epsabs=1.4e-8,
epsrel=1.4e-8,
divmax=20,
extrap=True,
):
"""Romberg integration of a callable function or method.
Expand All @@ -41,9 +40,6 @@ def romberg(
Not recommended for infinite intervals, or functions with singularities.
Algorithm is copied from SciPy, but in practice tends to underestimate the error
for even mildly bad integrands, sometimes by several orders of magnitude.
Parameters
----------
fun : callable
Expand All @@ -61,11 +57,9 @@ def romberg(
successive approximations to the integral, algorithm terminates
when abs(I1-I2) < max(epsabs, epsrel*|I2|)
divmax : int, optional
Maximum order of extrapolation. Default is 10.
Maximum order of extrapolation. Default is 20.
Total number of function evaluations will be at
most 2**divmax + 1
extrap : bool, optional
Whether to perform Richardson extrapolation.
Returns
-------
Expand All @@ -85,6 +79,14 @@ def romberg(
* table : (ndarray, size(dixmax+1, divmax+1)) Estimate of the integral
from each level of discretization and each step of extrapolation.
Notes
-----
Due to limitations on dynamically sized arrays in JAX, this algorithm is fully
sequential and does not vectorize integrand evaluations, so may not be the most
efficient on GPU/TPU.
Also, it is currently only forward mode differentiable.
"""
# map a, b -> [-1, 1]
fun = map_interval(fun, a, b)
Expand All @@ -102,11 +104,14 @@ def ncond(state):
return (n < divmax + 1) & (err > jnp.maximum(epsabs, epsrel * result[n, n]))

def nloop(state):
# loop over outer number of subdivisions
result, n, neval, err = state
h = (b - a) / 2**n
s = 0.0

def sloop(i, s):
# loop to evaluate fun. Can't be vectorized due to different number
# of evals per nloop step
s += vfunc(a + h * (2 * i - 1))
return s

Expand All @@ -118,12 +123,8 @@ def sloop(i, s):

def mloop(m, result):
# richardson extrapolation
temp = (
1
/ (4.0**m - 1.0)
* ((4.0**m) * result[n, m - 1] - result[n - 1, m - 1])
)
result = result.at[n, m].set(temp)
temp = 1 / (4.0**m - 1.0) * (result[n, m - 1] - result[n - 1, m - 1])
result = result.at[n, m].set(result[n, m - 1] + temp)
return result

result = jax.lax.fori_loop(1, n + 1, mloop, result)
Expand Down Expand Up @@ -168,7 +169,7 @@ def rombergts(
successive approximations to the integral, algorithm terminates
when abs(I1-I2) < max(epsabs, epsrel*|I2|)
divmax : int, optional
Maximum order of extrapolation. Default is 10.
Maximum order of extrapolation. Default is 20.
Total number of function evaluations will be at
most 2**divmax + 1
Expand All @@ -190,6 +191,14 @@ def rombergts(
* table : (ndarray, size(dixmax+1, divmax+1)) Estimate of the integral
from each level of discretization and each step of extrapolation.
Notes
-----
Due to limitations on dynamically sized arrays in JAX, this algorithm is fully
sequential and does not vectorize integrand evaluations, so may not be the most
efficient on GPU/TPU.
Also, it is currently only forward mode differentiable.
"""
# map a, b -> [-1, 1]
fun = map_interval(fun, a, b)
Expand Down
56 changes: 20 additions & 36 deletions tests/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ def test_prob5(self):
def test_prob6(self):
"""Test for example problem #6."""
self._base(6, 1e-4)
self._base(6, 1e-8, 10)
self._base(6, 1e-12, 1e5, divmax=22)
self._base(6, 1e-8, fudge=10)
self._base(6, 1e-12, divmax=22, fudge=1e5)

def test_prob7(self):
"""Test for example problem #7."""
Expand All @@ -490,8 +490,8 @@ def test_prob8(self):
def test_prob9(self):
"""Test for example problem #9."""
self._base(9, 1e-4)
self._base(9, 1e-8, 10)
self._base(9, 1e-12, 1e5)
self._base(9, 1e-8, fudge=10)
self._base(9, 1e-12, fudge=1e5)

def test_prob10(self):
"""Test for example problem #10."""
Expand All @@ -502,8 +502,8 @@ def test_prob10(self):
def test_prob11(self):
"""Test for example problem #11."""
self._base(11, 1e-4)
self._base(11, 1e-8, 10)
self._base(11, 1e-12, 1e5, divmax=25)
self._base(11, 1e-8, fudge=10)
self._base(11, 1e-12, fudge=1e5)

def test_prob12(self):
"""Test for example problem #12."""
Expand All @@ -521,10 +521,13 @@ def test_prob13(self):
class TestRomberg:
"""Tests for Romberg's method (only for well behaved integrands)."""

def _base(self, i, tol, fudge=1):
def _base(self, i, tol, fudge=1, **kwargs):
prob = example_problems[i]
y, info = romberg(prob["fun"], prob["a"], prob["b"], epsabs=tol, epsrel=tol)
assert info.err < max(tol, tol * y)
y, info = romberg(
prob["fun"], prob["a"], prob["b"], epsabs=tol, epsrel=tol, **kwargs
)
if info.status == 0:
assert info.err < max(tol, tol * y)
np.testing.assert_allclose(
y,
prob["val"],
Expand Down Expand Up @@ -559,66 +562,47 @@ def test_prob3(self):

def test_prob4(self):
"""Test for example problem #4."""
self._base(4, 1e-4, 100)
self._base(4, 1e-8, 1e5)
self._base(4, 1e-12, 1e7)
self._base(4, 1e-4)
self._base(4, 1e-8)
self._base(4, 1e-12, divmax=27)

def test_prob5(self):
"""Test for example problem #5."""
self._base(5, 1e-4, 10)
self._base(5, 1e-8, 1e4)
self._base(5, 1e-12, 1e6)
self._base(5, 1e-4)
self._base(5, 1e-8)
self._base(5, 1e-12, divmax=25)

@pytest.mark.xfail
def test_prob6(self):
"""Test for example problem #6."""
self._base(6, 1e-4)
self._base(6, 1e-8)
self._base(6, 1e-12)
self._base(6, 1e-4, fudge=10)

@pytest.mark.xfail
def test_prob7(self):
"""Test for example problem #7."""
self._base(7, 1e-4)
self._base(7, 1e-8)
self._base(7, 1e-12)

@pytest.mark.xfail
def test_prob8(self):
"""Test for example problem #8."""
self._base(8, 1e-4)
self._base(8, 1e-8)
self._base(8, 1e-12)

@pytest.mark.xfail
def test_prob9(self):
"""Test for example problem #9."""
self._base(9, 1e-4)
self._base(9, 1e-8)
self._base(9, 1e-12)

@pytest.mark.xfail
def test_prob10(self):
"""Test for example problem #10."""
self._base(10, 1e-4)
self._base(10, 1e-8)
self._base(10, 1e-12)

@pytest.mark.xfail
def test_prob11(self):
"""Test for example problem #11."""
self._base(11, 1e-4)
self._base(11, 1e-8)
self._base(11, 1e-12)
self._base(11, 1e-4, fudge=10)

@pytest.mark.xfail
def test_prob12(self):
"""Test for example problem #12."""
self._base(12, 1e-4)
self._base(12, 1e-8)
self._base(12, 1e-12)

@pytest.mark.xfail
def test_prob13(self):
"""Test for example problem #13."""
self._base(13, 1e-4)
Expand Down

0 comments on commit 15e54d2

Please sign in to comment.