quadax is a library for numerical quadrature and integration using JAX.
- `vmap`-able, `jit`-able, differentiable.
- Scalar- or vector-valued integrands.
- Finite or infinite domains with discontinuities or singularities within the domain of integration.
- Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to `scipy.integrate.quad`).
- Adaptive tanh-sinh quadrature for singular or near-singular integrands.
- Quadrature from sampled values using the trapezoidal and Simpson's rules.
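The first bullet is what lets quadrature compose with the rest of a JAX program. As a rough illustration of that property only, here is a minimal sketch using a fixed 20-point Gauss-Legendre rule rather than quadax's adaptive schemes; `gauss_legendre` and `integral` are hypothetical helpers written for this example, not part of the quadax API.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Fixed 20-point Gauss-Legendre nodes and weights on [-1, 1], computed once with NumPy.
_nodes, _weights = np.polynomial.legendre.leggauss(20)
NODES = jnp.asarray(_nodes)
WEIGHTS = jnp.asarray(_weights)


def gauss_legendre(fun, a, b):
    """Approximate the integral of fun over [a, b] with a fixed Gauss-Legendre rule."""
    # Affine map from the reference interval [-1, 1] to [a, b].
    half = 0.5 * (b - a)
    mid = 0.5 * (b + a)
    t = mid + half * NODES
    return half * jnp.sum(WEIGHTS * fun(t))


def integral(c):
    # Parameterized integral I(c) = int_0^1 exp(-c t^2) dt.
    return gauss_legendre(lambda t: jnp.exp(-c * t**2), 0.0, 1.0)


I_jit = jax.jit(integral)(1.0)                          # compiled evaluation at c = 1
I_batch = jax.vmap(integral)(jnp.array([0.5, 1.0, 2.0]))  # one integral per value of c
dI = jax.grad(integral)(1.0)                            # dI/dc = -int_0^1 t^2 exp(-c t^2) dt
```

Because a fixed rule is just a sequence of array operations, `jit`, `vmap` and `grad` apply to it directly. Adaptive schemes refine intervals inside a loop, so their derivatives come from differentiating through that loop, which is the less efficient path noted under "Coming soon".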
Coming soon:
- Custom JVP/VJP rules (currently AD works by differentiating through the loop, which is not the most efficient approach).
- N-D quadrature (cubature)
- QMC methods
- Integration with weight functions
- Sparse grids (maybe, need to play with data structures and JAX)
quadax is installable with pip:

```bash
pip install quadax
```
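```python
import jax.numpy as jnp
import numpy as np
from quadax import quadgk

# Integrand with a known closed form: the integral of t*log(1+t) over [0, 1] is 1/4
fun = lambda t: t * jnp.log(1 + t)
# JAX uses 32-bit floats by default; tighter tolerances require enabling 64-bit mode,
# e.g. jax.config.update("jax_enable_x64", True) before any arrays are created
epsabs = epsrel = 1e-5
a, b = 0, 1
y, info = quadgk(fun, [a, b], epsabs=epsabs, epsrel=epsrel)
# info.err is the estimated error of y
assert info.err < max(epsabs, epsrel * abs(y))
np.testing.assert_allclose(y, 1 / 4, rtol=epsrel, atol=epsabs)
```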
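To illustrate quadrature from sampled values, here is a minimal NumPy sketch of the composite trapezoidal and Simpson's rules applied to samples of the same integrand as the example above, whose exact integral on [0, 1] is 1/4. `trapezoid_sampled` and `simpson_sampled` are hypothetical helpers, not quadax functions; quadax's own sampled-value routines may differ in signature and in how they handle uneven spacing or an even number of samples.

```python
import numpy as np


def trapezoid_sampled(y, x):
    """Composite trapezoidal rule on samples y = f(x) (spacing may be uneven)."""
    h = np.diff(x)
    return np.sum(h * (y[:-1] + y[1:]) / 2)


def simpson_sampled(y, x):
    """Composite Simpson's rule on equally spaced samples y = f(x).

    Assumes an even number of intervals, i.e. an odd number of samples.
    """
    n = len(x) - 1
    if n < 2 or n % 2 != 0:
        raise ValueError("Simpson's rule needs an even number of intervals")
    h = (x[-1] - x[0]) / n
    # Weights h/3 * (1, 4, 2, 4, ..., 2, 4, 1): odd interior points get 4, even interior get 2.
    return h / 3 * (y[0] + y[-1] + 4 * y[1:-1:2].sum() + 2 * y[2:-1:2].sum())


x = np.linspace(0.0, 1.0, 101)  # 101 samples, 100 intervals, h = 0.01
y = x * np.log1p(x)             # same integrand as the quadgk example
trap = trapezoid_sampled(y, x)
simp = simpson_sampled(y, x)
```

With 100 intervals, Simpson's rule (error O(h^4)) lands far closer to 1/4 than the trapezoidal rule (error O(h^2)).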
For full details of the available options, see the API documentation.