From fa14060bd35f97bfdbe2c3fb640f53f950c2dc4e Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 3 Dec 2024 10:50:55 -0500 Subject: [PATCH] Fix benign bug in midpoint evaluation. Resolves #16 --- quadax/adaptive.py | 2 +- quadax/romberg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/quadax/adaptive.py b/quadax/adaptive.py index 04d2f42..c929698 100644 --- a/quadax/adaptive.py +++ b/quadax/adaptive.py @@ -422,7 +422,7 @@ def adaptive_quadrature( epsrel = setdefault(epsrel, jnp.sqrt(jnp.finfo(jnp.array(1.0)).eps)) fun, interval = map_interval(fun, interval) vfunc = wrap_func(fun, args) - f = jax.eval_shape(vfunc, (interval[0] + interval[-1] / 2)) + f = jax.eval_shape(vfunc, (interval[0] + interval[-1]) / 2) epmach = jnp.finfo(f.dtype).eps shape = f.shape diff --git a/quadax/romberg.py b/quadax/romberg.py index 1e2b445..573ee4f 100644 --- a/quadax/romberg.py +++ b/quadax/romberg.py @@ -97,7 +97,7 @@ def romberg( fun, interval = map_interval(fun, interval) vfunc = wrap_func(fun, args) a, b = interval - f = jax.eval_shape(vfunc, (a + b / 2)) + f = jax.eval_shape(vfunc, (a + b) / 2) result = jnp.zeros((divmax + 1, divmax + 1, *f.shape), f.dtype) result = result.at[0, 0].set(vfunc(a) + vfunc(b))