Skip to content

Commit

Permalink
Fixes for complex integrands (#20)
Browse files Browse the repository at this point in the history
- Fixes some type mismatches for complex valued integrands
- Adds full testing for integration of complex valued integrands

Resolves #19
  • Loading branch information
f0uriest authored Dec 7, 2024
2 parents 0cb837c + b28fb80 commit 75fbdd9
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ quadax is a library for numerical quadrature and integration using JAX.
- ``vmap``-able, ``jit``-able, differentiable.
- Scalar or vector valued integrands.
- Finite or infinite domains with discontinuities or singularities within the domain of integration.
- Globally adaptive Gauss-Konrod and Clenshaw-Curtis quadrature for smooth integrands (similar to ``scipy.integrate.quad``)
- Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to ``scipy.integrate.quad``)
- Adaptive tanh-sinh quadrature for singular or near singular integrands.
- Quadrature from sampled values using trapezoidal and Simpsons methods.

Expand Down
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Adaptive integration of a callable function or method
:toctree: _api/
:recursive:

quadgk -- General purpose integration using Gauss-Konrod scheme
quadgk -- General purpose integration using Gauss-Kronrod scheme
quadcc -- General purpose integration using Clenshaw-Curtis scheme
quadts -- General purpose integration using tanh-sinh (aka double exponential) scheme
romberg -- Adaptive trapezoidal integration with Richardson extrapolation
Expand All @@ -27,7 +27,7 @@ Quadrature Rules
:template: class.rst

AbstractQuadratureRule -- Abstract base class for all quadrature rules
GaussKronrodRule -- Fixed order integration over finite interval using Gauss-Konrod scheme
GaussKronrodRule -- Fixed order integration over finite interval using Gauss-Kronrod scheme
ClenshawCurtisRule -- Fixed order integration over finite interval using Clenshaw-Curtis scheme
TanhSinhRule -- Fixed order integration over finite interval using tanh-sinh (aka double exponential) scheme

Expand Down
10 changes: 6 additions & 4 deletions quadax/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def quadgk(
order=21,
norm=jnp.inf,
):
"""Global adaptive quadrature using Gauss-Konrod rule.
"""Global adaptive quadrature using Gauss-Kronrod rule.
Integrate fun from `interval[0]` to `interval[-1]` using a h-adaptive scheme with
error estimate. Breakpoints can be specified in `interval` where integration
Expand Down Expand Up @@ -429,20 +429,22 @@ def adaptive_quadrature(
state = {}
state["neval"] = 0 # number of evaluations of local quadrature rule
state["ninter"] = len(interval) - 1 # current number of intervals
state["r_arr"] = jnp.zeros((max_ninter, *shape)) # local results from each interval
state["r_arr"] = jnp.zeros(
(max_ninter, *shape), f.dtype
) # local results from each interval
state["e_arr"] = jnp.zeros(max_ninter) # local error est. from each interval
state["a_arr"] = jnp.zeros(max_ninter) # start of each interval
state["b_arr"] = jnp.zeros(max_ninter) # end of each interval
state["s_arr"] = jnp.zeros(
(max_ninter, *shape)
(max_ninter, *shape), f.dtype
) # global est. of I from n intervals
state["a_arr"] = state["a_arr"].at[: state["ninter"]].set(interval[:-1])
state["b_arr"] = state["b_arr"].at[: state["ninter"]].set(interval[1:])
state["roundoff1"] = 0 # for keeping track of roundoff errors
state["roundoff2"] = 0 # for keeping track of roundoff errors
state["status"] = 0 # error flag
state["err_bnd"] = 0.0 # error bound we're trying to reach
state["area"] = jnp.zeros(shape) # current best estimate for I
state["area"] = jnp.zeros(shape, f.dtype) # current best estimate for I
state["err_sum"] = 0.0 # current estimate for error in I

def init_body(i, state_):
Expand Down
2 changes: 1 addition & 1 deletion quadax/fixed_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def integrate(
def truefun():
f = jax.eval_shape(vfun, jnp.array(0.0))
z = jnp.zeros(f.shape, f.dtype)
return z, 0.0, z, z
return z, self.norm(z), jnp.abs(z), jnp.abs(z)

def falsefun():

Expand Down
38 changes: 37 additions & 1 deletion tests/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,17 @@
"interval": [0, 1, 2],
"val": -4,
},
# problem 16 - complex function
{
"fun": lambda t: t * jnp.log(1 + t) * 1j,
"interval": [0, 1],
"val": 0.25j,
},
]


class TestQuadGK:
"""Tests for Gauss-Konrod quadrature."""
"""Tests for Gauss-Kronrod quadrature."""

def _base(self, i, tol, fudge=1, **kwargs):
prob = example_problems[i]
Expand Down Expand Up @@ -210,6 +216,12 @@ def test_prob15(self):
self._base(14, 1e-8)
self._base(14, 1e-12)

def test_prob16(self):
"""Test for example problem #16."""
self._base(16, 1e-4)
self._base(16, 1e-8)
self._base(16, 1e-12)


class TestQuadCC:
"""Tests for Clenshaw-Curtis quadrature."""
Expand Down Expand Up @@ -331,6 +343,12 @@ def test_prob15(self):
self._base(14, 1e-8)
self._base(14, 1e-12)

def test_prob16(self):
"""Test for example problem #16."""
self._base(16, 1e-4)
self._base(16, 1e-8)
self._base(16, 1e-12)


class TestQuadTS:
"""Tests for adaptive tanh-sinh quadrature."""
Expand Down Expand Up @@ -452,6 +470,12 @@ def test_prob15(self):
self._base(14, 1e-8)
self._base(14, 1e-12)

def test_prob16(self):
"""Test for example problem #16."""
self._base(16, 1e-4)
self._base(16, 1e-8)
self._base(16, 1e-12)


class TestRombergTS:
"""Tests for tanh-sinh quadrature with adaptive refinement."""
Expand Down Expand Up @@ -567,6 +591,12 @@ def test_prob15(self):
self._base(14, 1e-8)
self._base(14, 1e-12)

def test_prob16(self):
"""Test for example problem #16."""
self._base(16, 1e-4)
self._base(16, 1e-8)
self._base(16, 1e-12)


class TestRomberg:
"""Tests for Romberg's method (only for well behaved integrands)."""
Expand Down Expand Up @@ -670,3 +700,9 @@ def test_prob15(self):
self._base(14, 1e-4)
self._base(14, 1e-8)
self._base(14, 1e-12)

def test_prob16(self):
"""Test for example problem #16."""
self._base(16, 1e-4)
self._base(16, 1e-8)
self._base(16, 1e-12)

0 comments on commit 75fbdd9

Please sign in to comment.