Skip to content

Commit

Permalink
Clean up docs
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Oct 30, 2023
1 parent 15e54d2 commit edbf6a2
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 58 deletions.
5 changes: 3 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ Usage
f = lambda t: t * jnp.log(1 + t)
y, err = quadgk(fun, 0, 1, epsabs=1e-14, epsrel=1e-14)
epsabs = epsrel = 1e-14
y, info = quadgk(fun, 0, 1, epsabs=epsabs, epsrel=epsrel)
assert info.err < max(epsabs, epsrel*abs(y))
np.testing.assert_allclose(y, 1/4, rtol=1e-14, atol=1e-14)
Expand Down
67 changes: 40 additions & 27 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,51 @@
API Documentation
=================

quadgk
******
.. autofunction:: quadax.quadgk
.. currentmodule:: quadax

quadcc
******
.. autofunction:: quadax.quadcc
Adaptive integration of a callable function or method
-----------------------------------------------------

quadts
******
.. autofunction:: quadax.quadts
.. autosummary::
:toctree: _api/
:recursive:

romberg
*******
.. autofunction:: quadax.romberg
quadgk -- General purpose integration using Gauss-Konrod scheme
quadcc -- General purpose integration using Clenshaw-Curtis scheme
quadts -- General purpose integration using tanh-sinh (aka double exponential) scheme
romberg -- Adaptive trapezoidal integration with Richardson extrapolation
rombergts -- Adaptive tanh-sinh integration with Richardson extrapolation

rombergts
*********
.. autofunction:: quadax.rombergts

adaptive_quadrature
*******************
.. autofunction:: quadax.adaptive_quadrature
Fixed order integration of a callable function or method
--------------------------------------------------------

trapezoid
*********
.. autofunction:: quadax.trapezoid
.. autosummary::
:toctree: _api/
:recursive:

cumulative_trapezoid
********************
.. autofunction:: quadax.cumulative_trapezoid
fixed_quadgk -- Fixed order integration over finite interval using Gauss-Konrod scheme
fixed_quadcc -- Fixed order integration over finite interval using Clenshaw-Curtis scheme
fixed_quadts -- Fixed order integration over finite interval using tanh-sinh (aka double exponential) scheme

simpson
*******
.. autofunction:: quadax.simpson

Integrating function from sampled values
----------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

trapezoid -- Use trapezoidal rule to approximate definite integral.
cumulative_trapezoid -- Use trapezoidal rule to approximate indefinite integral.
simpson -- Use Simpson's rule to compute integral from samples.


Low level routines and wrappers
-------------------------------

.. autosummary::
:toctree: _api/
:recursive:

adaptive_quadrature -- Custom h-adaptive quadrature using user specified local rule.
35 changes: 34 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,41 @@
.. include:: ../README.rst


Which method should I choose?
=============================
Can you evaluate the integrand at an arbitary point?
----------------------------------------------------

To start, ``quadgk`` or ``quadcc`` are probably your best options, and are similar to
methods in QUADPACK (or ``scipy.integrate.quad``). ``quadgk`` is usually the most efficient
for very smooth integrands (well approximated by a high degree polynomial), ``quadcc``
tends to be slightly more efficient for less smooth integrands. If both of those don't
perform well, you should think about your integrand a bit more:

- Does your integrand have badly behaved singularites at the endpoints? Use ``quadts`` or ``rombergts``
- Is your integrand only piecewise smooth or piecewise continuous? Use ``romberg`` or ``rombergts``

Do you only know your integrand at discrete points?
---------------------------------------------------
- Use ``trapezoid`` or ``simspson``


Notes on parallel efficiency
============================
Adaptive algorithms are inherently somewhat sequential, so perfect parallelism
is generally not achievable. ``romberg`` and ``rombergts`` are fully sequential, due to
limitiations on dynamically sized arrays in JAX. All of the ``quad*`` methods are parallelized
on a local level (ie, for each sub-interval, the function evaluations are vectorized).
This means that ``quad*`` methods will evaluate the integrand in batch sizes of ``order``,
and hence higher order methods will tend to be more efficient on GPU/TPU. However, if the
integrand is not sufficiently smooth, using a higher order method can slow down convergence,
particularly for ``quadgk``, ``quadts`` and ``quadcc`` are somewhat less sensitive to the
smoothness of the integrand.



.. toctree::
:maxdepth: 2
:maxdepth: 4
:caption: Public API

api
Expand Down
38 changes: 19 additions & 19 deletions quadax/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def quadgk(
"""
y, info = adaptive_quadrature(
fun, a, b, args, full_output, epsabs, epsrel, max_ninter, fixed_quadgk, n=order
fixed_quadgk, fun, a, b, args, full_output, epsabs, epsrel, max_ninter, n=order
)
info = QuadratureInfo(info.err, info.neval * order, info.status, info.info)
return y, info
Expand Down Expand Up @@ -137,7 +137,7 @@ def quadcc(
max_ninter : int, optional
An upper bound on the number of sub-intervals used in the adaptive
algorithm.
n : {8, 16, 32, 64, 128, 256}
order : {8, 16, 32, 64, 128, 256}
Order of local integration rule.
Returns
Expand Down Expand Up @@ -177,7 +177,7 @@ def quadcc(
"""
y, info = adaptive_quadrature(
fun, a, b, args, full_output, epsabs, epsrel, max_ninter, fixed_quadcc, n=order
fixed_quadcc, fun, a, b, args, full_output, epsabs, epsrel, max_ninter, n=order
)
info = QuadratureInfo(info.err, info.neval * order, info.status, info.info)
return y, info
Expand Down Expand Up @@ -220,7 +220,7 @@ def quadts(
max_ninter : int, optional
An upper bound on the number of sub-intervals used in the adaptive
algorithm.
n : {41, 61, 81, 101}
order : {41, 61, 81, 101}
Order of local integration rule.
Returns
Expand Down Expand Up @@ -260,13 +260,14 @@ def quadts(
"""
y, info = adaptive_quadrature(
fun, a, b, args, full_output, epsabs, epsrel, max_ninter, fixed_quadts, n=order
fixed_quadts, fun, a, b, args, full_output, epsabs, epsrel, max_ninter, n=order
)
info = QuadratureInfo(info.err, info.neval * order, info.status, info.info)
return y, info


def adaptive_quadrature(
rule,
fun,
a,
b,
Expand All @@ -275,17 +276,26 @@ def adaptive_quadrature(
epsabs=1.4e-8,
epsrel=1.4e-8,
max_ninter=50,
rule=None,
**kwargs
):
"""Global adaptive quadrature.
Integrate fun from a to b using an adaptive scheme with error estimate.
Differentiation wrt args is done via Leibniz rule.
This is a lower level routine allowing for custom local quadrature rules. For most
applications the higher order methods ``quadgk``, ``quadcc``, ``quadts`` are
preferable.
Parameters
----------
rule : callable
Local quadrature rule to use. It should have a signature of the form
``rule(fun, a, b, **kwargs)`` -> out, where out is array-like with 4 elements:
#. Estimate of the integral of fun from a to b
#. Estimate of the absolute error in the integral (ie, from nested scheme).
#. Estimate of the integral of abs(fun) from a to b
#. Estimate of the integral of abs(fun - <fun>) from a to b, where <fun> is
the mean value of fun over the interval.
fun : callable
Function to integrate, should have a signature of the form
``fun(x, *args)`` -> float. Should be JAX transformable.
Expand All @@ -304,16 +314,6 @@ def adaptive_quadrature(
max_ninter : int, optional
An upper bound on the number of sub-intervals used in the adaptive
algorithm.
rule : callable
Local quadrature rule to use. It should have a signature of the form
``rule(fun, a, b, **kwargs)`` -> out, where out is array-like with 4 elements:
#. Estimate of the integral of fun from a to b
#. Estimate of the absolute error in the integral (ie, from nested scheme).
#. Estimate of the integral of abs(fun) from a to b
#. Estimate of the integral of abs(fun - <fun>) from a to b, where <fun> is
the mean value of fun over the interval.
kwargs : dict
Additional keyword arguments passed to ``rule``.
Expand Down
18 changes: 9 additions & 9 deletions quadax/fixed_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@


@functools.partial(jax.jit, static_argnums=(0, 4))
def fixed_quadgk(fun, a, b, args, n=21):
def fixed_quadgk(fun, a, b, args=(), n=21):
"""Integrate a function from a to b using a fixed order Gauss-Konrod rule.
Integration is performed using and order n Konrod rule with error estimated
Integration is performed using an order n Konrod rule with error estimated
using an embedded n//2 order Gauss rule.
Parameters
----------
fun : callable
Function to integrate, should have a signature of the form
fun(x, *args) -> float. Should be JAX transformable.
``fun(x, *args)`` -> float. Should be JAX transformable.
a, b : float
Lower and upper limits of integration. Must be finite.
args : tuple, optional
Expand Down Expand Up @@ -89,17 +89,17 @@ def falsefun():
return jax.lax.cond(a == b, truefun, falsefun)


def fixed_quadcc(fun, a, b, args, n=32):
def fixed_quadcc(fun, a, b, args=(), n=32):
"""Integrate a function from a to b using a fixed order Clenshaw-Curtis rule.
Integration is performed using and order n rule with error estimated
Integration is performed using an order n rule with error estimated
using an embedded n//2 order rule.
Parameters
----------
fun : callable
Function to integrate, should have a signature of the form
fun(x, *args) -> float. Should be JAX transformable.
``fun(x, *args)`` -> float. Should be JAX transformable.
a, b : float
Lower and upper limits of integration. Must be finite.
args : tuple, optional
Expand Down Expand Up @@ -171,17 +171,17 @@ def falsefun():
return jax.lax.cond(a == b, truefun, falsefun)


def fixed_quadts(fun, a, b, args, n=61):
def fixed_quadts(fun, a, b, args=(), n=61):
"""Integrate a function from a to b using a fixed order tanh-sinh rule.
Integration is performed using and order n rule with error estimated
Integration is performed using an order n rule with error estimated
using an embedded n//2 order rule.
Parameters
----------
fun : callable
Function to integrate, should have a signature of the form
fun(x, *args) -> float. Should be JAX transformable.
``fun(x, *args)`` -> float. Should be JAX transformable.
a, b : float
Lower and upper limits of integration. Must be finite.
args : tuple, optional
Expand Down

0 comments on commit edbf6a2

Please sign in to comment.