Skip to content

Commit

Permalink
Add clenshaw curtis rules
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Oct 27, 2023
1 parent a9d41d0 commit 4291d27
Show file tree
Hide file tree
Showing 10 changed files with 698 additions and 267 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ quadgk
******
.. autofunction:: quadax.quadgk

quadcc
******
.. autofunction:: quadax.quadcc

quadts
******
.. autofunction:: quadax.quadts
Expand Down
5 changes: 3 additions & 2 deletions quadax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""quadax : numerical quadrature with JAX."""

from . import _version
from .adaptive import adaptive_quadrature, quadgk
from .fixed_qk import fixed_quadgk
from .adaptive import adaptive_quadrature, quadcc, quadgk
from .fixed_order import fixed_quadcc, fixed_quadgk
from .romberg import romberg
from .sampled import cumulative_trapezoid, simpson, trapezoid
from .tanhsinh import quadts
from .utils import STATUS

__version__ = _version.get_versions()["version"]
113 changes: 89 additions & 24 deletions quadax/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax
import jax.numpy as jnp

from .fixed_qk import fixed_quadgk
from .utils import map_interval
from .fixed_qk import fixed_quadcc, fixed_quadgk
from .utils import map_interval, wrap_func

NORMAL_EXIT = 0
MAX_NINTER = 1
Expand All @@ -29,7 +29,9 @@ def quadgk(
Integrate fun from a to b using a h-adaptive scheme with error estimate.
Differentiation wrt args is done via Liebniz rule.
Basically the same as ``scipy.integrate.quad`` but without extrapolation. A good
general purpose integrator for most reasonably well behaved functions over finite
or infinite intervals.
Parameters
----------
Expand Down Expand Up @@ -86,6 +88,79 @@ def quadgk(
return out


def quadcc(
fun,
a,
b,
args=(),
full_output=False,
epsabs=1.4e-8,
epsrel=1.4e-8,
max_ninter=50,
order=32,
):
"""Global adaptive quadrature using Clenshaw-Curtis rule.
Integrate fun from a to b using a h-adaptive scheme with error estimate.
A good general purpose integrator for most reasonably well behaved functions over
finite or infinite intervals.
Parameters
----------
fun : callable
Function to integrate, should have a signature of the form
``fun(x, *args)`` -> float. Should be JAX transformable.
a, b : float
Lower and upper limits of integration. Use np.inf to denote infinite intervals.
args : tuple, optional
Extra arguments passed to fun.
full_output : bool, optional
If True, return the full state of the integrator. See below for more
information.
epsabs, epsrel : float, optional
Absolute and relative error tolerance. Default is 1.4e-8. Algorithm tries to
obtain an accuracy of ``abs(i-result) <= max(epsabs, epsrel*abs(i))``
where ``i`` = integral of `fun` from `a` to `b`, and ``result`` is the
numerical approximation.
max_ninter : int, optional
An upper bound on the number of sub-intervals used in the adaptive
algorithm.
n : {8, 16, 32, 64, 128, 256}
Order of local integration rule.
Returns
-------
y : float
The integral of fun from `a` to `b`.
err : float
An estimate of the absolute error in the result.
state : dict
Final state of the algorithm. Only returned if full_output=True
The entries are:
- 'neval' : (int) The number of function evaluations.
- 'ninter' : (int) The number, K, of sub-intervals produced in the subdivision
process.
- 'a_arr' : (ndarray) rank-1 array of length max_ninter, the first K elements
of which are the left end points of the (remapped) sub-intervals in the
partition of the integration range.
- 'b_arr' : (ndarray) rank-1 array of length max_ninter, the first K elements of
which are the right end points of the (remapped) sub-intervals.
- 'r_arr' : (ndarray) rank-1 array of length max_ninter, the first K elements of
which are the integral approximations on the sub-intervals.
- 'e_arr' : (ndarray) rank-1 array of length max_ninter, the first K elements of
which are the moduli of the absolute error estimates on the sub-intervals.
"""
out = adaptive_quadrature(
fun, a, b, args, full_output, epsabs, epsrel, max_ninter, fixed_quadcc, n=order
)
if full_output:
out[2]["neval"] *= order
return out


def adaptive_quadrature(
fun,
a,
Expand All @@ -102,7 +177,7 @@ def adaptive_quadrature(
Integrate fun from a to b using an adaptive scheme with error estimate.
Differentiation wrt args is done via Liebniz rule.
Differentiation wrt args is done via Leibniz rule.
Parameters
----------
Expand Down Expand Up @@ -160,11 +235,10 @@ def adaptive_quadrature(
which are the moduli of the absolute error estimates on the sub-intervals.
"""
vfunc = jax.jit(jnp.vectorize(lambda x: fun(x, *args)))
vfunc = map_interval(vfunc, a, b)
fun = map_interval(fun, a, b)
vfunc = wrap_func(fun, args)

f = vfunc(jnp.array([0.5 * (a + b)])) # call it once to get dtype info
uflow = jnp.finfo(f.dtype).tiny
epmach = jnp.finfo(f.dtype).eps
a, b = -1, 1

Expand Down Expand Up @@ -196,13 +270,11 @@ def adaptive_quadrature(
state["s_arr"] = state["s_arr"].at[0].set(result)

# check for roundoff error - error too big but relative error is small
state["status"] += (
2**ROUNDOFF
* ROUNDOFF
* ((abserr <= (100.0 * epmach * intabs)) & (abserr > state["err_bnd"]))
state["status"] += 2**ROUNDOFF * (
(abserr <= (100.0 * epmach * intabs)) & (abserr > state["err_bnd"])
)
# check for max intervals exceeded
state["status"] += 2**MAX_NINTER * MAX_NINTER * (max_ninter == 0)
state["status"] += 2**MAX_NINTER * (max_ninter == 0)

def condfun(state):
return (
Expand Down Expand Up @@ -246,23 +318,16 @@ def bodyfun(state):
) & (erro12 >= 0.99 * jnp.max(state["e_arr"]))
# are errors getting larger as we go to smaller intervals?
state["roundoff2"] += (n > 10) & (erro12 > jnp.max(state["e_arr"]))
state["status"] += (
2**ROUNDOFF
* ROUNDOFF
* ((state["roundoff1"] >= 10) | (state["roundoff2"] >= 20))
state["status"] += 2**ROUNDOFF * (
(state["roundoff1"] >= 10) | (state["roundoff2"] >= 20)
)

# test for max number of intervals
state["status"] += 2**MAX_NINTER * MAX_NINTER * (n == max_ninter)
state["status"] += 2**MAX_NINTER * (n == max_ninter)

# test for bad behavior of the integrand (ie, intervals are getting too small)
state["status"] += (
2**BAD_INTEGRAND
* BAD_INTEGRAND
* (
jnp.maximum(jnp.abs(a1), jnp.abs(b2))
<= (1.0 + 100.0 * epmach) * (jnp.abs(a2) + 1000.0 * uflow)
)
state["status"] += 2**BAD_INTEGRAND * (
jnp.maximum(jnp.abs(b1 - a1), jnp.abs(b2 - a2)) <= (100.0 * epmach)
)

# update the arrays of interval starts/ends etc
Expand Down
171 changes: 171 additions & 0 deletions quadax/fixed_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Fixed order quadrature."""

import functools

import jax
import jax.numpy as jnp

from .quad_weights import cc_weights, gk_weights
from .utils import wrap_func


@functools.partial(jax.jit, static_argnums=(0, 4))
def fixed_quadgk(fun, a, b, args, n=21):
"""Integrate a function from a to b using a fixed order Gauss-Konrod rule.
Integration is performed using and order n Konrod rule with error estimated
using an embedded n//2 order Gauss rule.
Parameters
----------
fun : callable
Function to integrate, should have a signature of the form
fun(x, *args) -> float. Should be JAX transformable.
a, b : float
Lower and upper limits of integration. Must be finite.
args : tuple, optional
Extra arguments passed to fun.
n : {15, 21, 31, 41, 51, 61}
Order of integration scheme.
Returns
-------
y : float
Estimate of the integral of fun from a to b
err : float
Estimate of the absolute error in y from nested Gauss rule.
y_abs : float
Estimate of the integral of abs(fun) from a to b
y_mmn : float
Estimate of the integral of abs(fun - <fun>) from a to b, where <fun>
is the mean value of fun over the interval.
"""
vfun = wrap_func(fun, args)

def truefun():
return 0.0, 0.0, 0.0, 0.0

def falsefun():
try:
xk, wk, wg = (
gk_weights[n]["xk"],
gk_weights[n]["wk"],
gk_weights[n]["wg"],
)
except KeyError as e:
raise NotImplementedError(
f"order {n} not implemented, should be one of {gk_weights.keys()}"
) from e

halflength = (b - a) / 2
center = (b + a) / 2
f = vfun(center + halflength * xk)
result_konrod = jnp.sum(wk * f) * halflength
result_gauss = jnp.sum(wg * f) * halflength

integral_abs = jnp.sum(wk * jnp.abs(f)) # ~integral of abs(fun)
integral_mmn = jnp.sum(
wk * jnp.abs(f - result_konrod / (b - a))
) # ~ integral of abs(fun - mean(fun))

result = result_konrod

uflow = jnp.finfo(f.dtype).tiny
eps = jnp.finfo(f.dtype).eps
abserr = jnp.abs(result_konrod - result_gauss)
abserr = jnp.where(
(integral_mmn != 0.0) & (abserr != 0.0),
integral_mmn * jnp.minimum(1.0, (200.0 * abserr / integral_mmn) ** 1.5),
abserr,
)
abserr = jnp.where(
(integral_abs > uflow / (50.0 * eps)),
jnp.maximum((eps * 50.0) * integral_abs, abserr),
abserr,
)
return result, abserr, integral_abs, integral_mmn

return jax.lax.cond(a == b, truefun, falsefun)


def fixed_quadcc(fun, a, b, args, n=32):
"""Integrate a function from a to b using a fixed order Clenshaw-Curtis rule.
Integration is performed using and order n rule with error estimated
using an embedded n//2 order rule.
Parameters
----------
fun : callable
Function to integrate, should have a signature of the form
fun(x, *args) -> float. Should be JAX transformable.
a, b : float
Lower and upper limits of integration. Must be finite.
args : tuple, optional
Extra arguments passed to fun.
n : {8, 16, 32, 64, 128, 256}
Order of integration scheme.
Returns
-------
y : float
Estimate of the integral of fun from a to b
err : float
Estimate of the absolute error in y from nested rule.
y_abs : float
Estimate of the integral of abs(fun) from a to b
y_mmn : float
Estimate of the integral of abs(fun - <fun>) from a to b, where <fun>
is the mean value of fun over the interval.
"""
vfun = wrap_func(fun, args)

def truefun():
return 0.0, 0.0, 0.0, 0.0

def falsefun():
try:
xc, wc, we = (
cc_weights[n]["xc"],
cc_weights[n]["wc"],
cc_weights[n]["we"],
)
except KeyError as e:
raise NotImplementedError(
f"order {n} not implemented, should be one of {cc_weights.keys()}"
) from e

halflength = (b - a) / 2
center = (b + a) / 2
fp = vfun(center + halflength * xc)
fm = vfun(center - halflength * xc)
result_2 = jnp.sum(wc * (fp + fm)) * halflength
result_1 = jnp.sum(we * (fp + fm)) * halflength

integral_abs = jnp.sum(
wc * (jnp.abs(fp) + jnp.abs(fm))
) # ~integral of abs(fun)
integral_mmn = jnp.sum(
wc * jnp.abs(fp + fm - result_2 / (b - a))
) # ~ integral of abs(fun - mean(fun))

result = result_2

uflow = jnp.finfo(fp.dtype).tiny
eps = jnp.finfo(fp.dtype).eps
abserr = jnp.abs(result_2 - result_1)
abserr = jnp.where(
(integral_mmn != 0.0) & (abserr != 0.0),
integral_mmn * jnp.minimum(1.0, (200.0 * abserr / integral_mmn) ** 1.5),
abserr,
)
abserr = jnp.where(
(integral_abs > uflow / (50.0 * eps)),
jnp.maximum((eps * 50.0) * integral_abs, abserr),
abserr,
)
return result, abserr, integral_abs, integral_mmn

return jax.lax.cond(a == b, truefun, falsefun)
Loading

0 comments on commit 4291d27

Please sign in to comment.