Skip to content

Commit

Permalink
Add cumulative simpson integration (#12)
Browse files Browse the repository at this point in the history
Resolves #9
  • Loading branch information
f0uriest authored Jul 22, 2024
2 parents 39eafa0 + d52a820 commit 09217f2
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Integrating function from sampled values
trapezoid -- Use trapezoidal rule to approximate definite integral.
cumulative_trapezoid -- Use trapezoidal rule to approximate indefinite integral.
simpson -- Use Simpson's rule to compute integral from samples.
cumulative_simpson -- Use Simpson's rule to approximate indefinite integral.


Low level routines and wrappers
Expand Down
2 changes: 1 addition & 1 deletion quadax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
fixed_quadts,
)
from .romberg import romberg, rombergts
from .sampled import cumulative_trapezoid, simpson, trapezoid
from .sampled import cumulative_simpson, cumulative_trapezoid, simpson, trapezoid
from .utils import STATUS

__version__ = _version.get_versions()["version"]
215 changes: 211 additions & 4 deletions quadax/sampled.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Quadrature of functions using known sample values."""

import functools
from typing import Callable, Union

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def _tupleset(t, i, value):
Expand All @@ -9,7 +15,10 @@ def _tupleset(t, i, value):
return tuple(l)


def trapezoid(y, x=None, dx=1.0, axis=-1):
@functools.partial(jax.jit, static_argnames="axis")
def trapezoid(
y: ArrayLike, *, x: Union[None, ArrayLike] = None, dx: float = 1.0, axis: int = -1
) -> jax.Array:
r"""
Integrate along the given axis using the composite trapezoidal rule.
Expand Down Expand Up @@ -104,7 +113,15 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)


def cumulative_trapezoid(y, x=None, dx=1.0, axis=-1, initial=None):
@functools.partial(jax.jit, static_argnames="axis")
def cumulative_trapezoid(
y: ArrayLike,
*,
x: Union[None, ArrayLike] = None,
dx: float = 1.0,
axis: int = -1,
initial: Union[ArrayLike, None] = None,
) -> jax.Array:
"""Cumulatively integrate y(x) using the composite trapezoidal rule.
Parameters
Expand Down Expand Up @@ -184,7 +201,9 @@ def cumulative_trapezoid(y, x=None, dx=1.0, axis=-1, initial=None):
return res


def _basic_simpson(y, start, stop, x, dx, axis):
def _basic_simpson(
y: jax.Array, start: int, stop: int, x: Union[jax.Array, None], dx: float, axis: int
) -> jax.Array:
nd = len(y.shape)
if start is None:
start = 0
Expand Down Expand Up @@ -221,7 +240,10 @@ def _basic_simpson(y, start, stop, x, dx, axis):
return result


def simpson(y, x=None, dx=1.0, axis=-1):
@functools.partial(jax.jit, static_argnames="axis")
def simpson(
y: ArrayLike, *, x: Union[None, ArrayLike] = None, dx: float = 1.0, axis: int = -1
) -> jax.Array:
"""Integrate y(x) from samples using the composite Simpson's rule.
If x is None, spacing of dx is assumed.
Expand Down Expand Up @@ -336,3 +358,188 @@ def simpson(y, x=None, dx=1.0, axis=-1):
if returnshape:
x = x.reshape(saveshape)
return result


def cumulative_simpson(
y: ArrayLike,
*,
x: Union[None, ArrayLike] = None,
dx: float = 1.0,
axis: int = -1,
initial: Union[ArrayLike, None] = None,
) -> jax.Array:
r"""Cumulatively integrate y(x) using the composite Simpson's 1/3 rule.
The integral of the samples at every point is calculated by assuming a
quadratic relationship between each point and the two adjacent points.
Parameters
----------
y : array_like
Values to integrate. Requires at least one point along `axis`. If two or fewer
points are provided along `axis`, Simpson's integration is not possible and the
result is calculated with `cumulative_trapezoid`.
x : array_like, optional
The coordinate to integrate along. Must have the same shape as `y` or
must be 1D with the same length as `y` along `axis`. `x` must also be
strictly increasing along `axis`.
If `x` is None (default), integration is performed using spacing `dx`
between consecutive elements in `y`.
dx : scalar or array_like, optional
Spacing between elements of `y`. Only used if `x` is None. Can either
be a float, or an array with the same shape as `y`, but of length one along
`axis`. Default is 1.0.
axis : int, optional
Specifies the axis to integrate along. Default is -1 (last axis).
initial : scalar or array_like, optional
If given, insert this value at the beginning of the returned result,
and add it to the rest of the result. Default is None, which means no
value at ``x[0]`` is returned and `res` has one element less than `y`
along the axis of integration. Can either be a float, or an array with
the same shape as `y`, but of length one along `axis`.
Returns
-------
res : ndarray
The result of cumulative integration of `y` along `axis`.
If `initial` is None, the shape is such that the axis of integration
has one less value than `y`. If `initial` is given, the shape is equal
to that of `y`.
Notes
-----
For an odd number of samples that are equally spaced the result is
exact if the function is a polynomial of order 3 or less. If
the samples are not equally spaced, then the result is exact only
if the function is a polynomial of order 2 or less.
"""
y = _ensure_float_array(y)

# validate `axis` and standardize to work along the last axis
original_y = y
original_shape = y.shape
try:
y = jnp.swapaxes(y, axis, -1)
except IndexError as e:
message = f"`axis={axis}` is not valid for `y` with `y.ndim={y.ndim}`."
raise ValueError(message) from e
if y.shape[-1] < 3:
res = cumulative_trapezoid(original_y, x, dx=dx, axis=axis, initial=None)
res = jnp.swapaxes(res, axis, -1)

elif x is not None:
x = _ensure_float_array(x)
message = (
"If given, shape of `x` must be the same as `y` or 1-D with "
"the same length as `y` along `axis`."
)
if not (
x.shape == original_shape
or (x.ndim == 1 and len(x) == original_shape[axis])
):
raise ValueError(message)

x = jnp.broadcast_to(x, y.shape) if x.ndim == 1 else jnp.swapaxes(x, axis, -1)
dx = jnp.diff(x, axis=-1)
dx = eqx.error_if(dx, dx <= 0, "Input x must be strictly increasing.")
res = _cumulatively_sum_simpson_integrals(
y, dx, _cumulative_simpson_unequal_intervals
)

else:
dx = _ensure_float_array(dx)
final_dx_shape = _tupleset(original_shape, axis, original_shape[axis] - 1)
alt_input_dx_shape = _tupleset(original_shape, axis, 1)
message = (
"If provided, `dx` must either be a scalar or have the same "
"shape as `y` but with only 1 point along `axis`."
)
if not (dx.ndim == 0 or dx.shape == alt_input_dx_shape):
raise ValueError(message)
dx = jnp.broadcast_to(dx, final_dx_shape)
dx = jnp.swapaxes(dx, axis, -1)
res = _cumulatively_sum_simpson_integrals(
y, dx, _cumulative_simpson_equal_intervals
)

if initial is not None:
initial = _ensure_float_array(initial)
alt_initial_input_shape = _tupleset(original_shape, axis, 1)
message = (
"If provided, `initial` must either be a scalar or have the "
"same shape as `y` but with only 1 point along `axis`."
)
if not (initial.ndim == 0 or initial.shape == alt_initial_input_shape):
raise ValueError(message)
initial = jnp.broadcast_to(initial, alt_initial_input_shape)
initial = jnp.swapaxes(initial, axis, -1)

res += initial
res = jnp.concatenate((initial, res), axis=-1)

res = jnp.swapaxes(res, -1, axis)
return res


def _cumulatively_sum_simpson_integrals(
y: jax.Array,
dx: jax.Array,
integration_func: Callable[[jax.Array, jax.Array], jax.Array],
) -> jax.Array:
"""Calculate cumulative sum of Simpson integrals.
Takes as input the integration function to be used.
The integration_func is assumed to return the cumulative sum using
composite Simpson's rule. Assumes the axis of summation is -1.
"""
sub_integrals_h1 = integration_func(y, dx)
sub_integrals_h2 = integration_func(y[..., ::-1], dx[..., ::-1])[..., ::-1]

shape = list(sub_integrals_h1.shape)
shape[-1] += 1
sub_integrals = jnp.empty(shape)
sub_integrals = sub_integrals.at[..., :-1:2].set(sub_integrals_h1[..., ::2])
sub_integrals = sub_integrals.at[..., 1::2].set(sub_integrals_h2[..., ::2])
# Integral over last subinterval can only be calculated from
# formula for h2
sub_integrals = sub_integrals.at[..., -1].set(sub_integrals_h2[..., -1])
res = jnp.cumsum(sub_integrals, axis=-1)
return res


def _cumulative_simpson_equal_intervals(y: jax.Array, dx: jax.Array) -> jax.Array:
"""Calculate the Simpson integrals assuming equal interval widths."""
d = dx[..., :-1]
f1 = y[..., :-2]
f2 = y[..., 1:-1]
f3 = y[..., 2:]

return d / 3 * (5 * f1 / 4 + 2 * f2 - f3 / 4)


def _cumulative_simpson_unequal_intervals(y: jax.Array, dx: jax.Array) -> jax.Array:
"""Calculate the Simpson integrals assuming unequal interval widths."""
x21 = dx[..., :-1]
x32 = dx[..., 1:]
f1 = y[..., :-2]
f2 = y[..., 1:-1]
f3 = y[..., 2:]

x31 = x21 + x32
x21_x31 = x21 / x31
x21_x32 = x21 / x32
x21x21_x31x32 = x21_x31 * x21_x32

coeff1 = 3 - x21_x31
coeff2 = 3 + x21x21_x31x32 + x21_x31
coeff3 = -x21x21_x31x32

return x21 / 6 * (coeff1 * f1 + coeff2 * f2 + coeff3 * f3)


def _ensure_float_array(arr):
arr = jnp.asarray(arr)
if jnp.issubdtype(arr.dtype, jnp.integer):
arr = arr.astype(float, copy=False)
return arr
46 changes: 41 additions & 5 deletions tests/test_sampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from jax import config

from quadax import cumulative_trapezoid, simpson, trapezoid
from quadax import cumulative_simpson, cumulative_trapezoid, simpson, trapezoid

config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -52,8 +52,8 @@ def _base(self, i, n, tol):
x2 = a + (b - a) * np.linspace(0, 1, n) ** 2
f1 = prob["fun"](x1)
f2 = prob["fun"](x2)
y1 = trapezoid(f1, x1)
y2 = trapezoid(f2, x2)
y1 = trapezoid(f1, x=x1)
y2 = trapezoid(f2, x=x2)
y3 = trapezoid(f1, dx=np.diff(x1)[0])
np.testing.assert_allclose(y1, y3)
np.testing.assert_allclose(y1, prob["val"], atol=tol, rtol=tol)
Expand Down Expand Up @@ -90,7 +90,7 @@ def _base(self, i, n, tol):
# evenly spaced points
x1 = a + (b - a) * np.linspace(0, 1, n)
f1 = prob["fun"](x1)
y1 = simpson(f1, x1)
y1 = simpson(f1, x=x1)
y3 = simpson(f1, dx=np.diff(x1)[0])
np.testing.assert_allclose(y1, y3)
np.testing.assert_allclose(y1, prob["val"], atol=tol, rtol=tol)
Expand Down Expand Up @@ -126,7 +126,7 @@ def _base(self, i, n, tol):
# evenly spaced points
x1 = a + (b - a) * np.linspace(0, 1, n)
f1 = prob["fun"](x1)
y1 = cumulative_trapezoid(f1, x1, initial=0) + prob["int"](a)
y1 = cumulative_trapezoid(f1, x=x1, initial=0) + prob["int"](a)
y3 = cumulative_trapezoid(f1, dx=np.diff(x1)[0], initial=0) + prob["int"](a)
np.testing.assert_allclose(y1, y3)
np.testing.assert_allclose(y1, prob["int"](x1), atol=tol, rtol=tol)
Expand All @@ -151,3 +151,39 @@ def test_prob2(self):
self._base(2, 20, 3e-2 / 4**1)
self._base(2, 40, 3e-2 / 4**2)
self._base(2, 80, 3e-2 / 4**3)


class TestCumulativeSimpson:
"""Tests for cumulative integration using simpsons rule."""

def _base(self, i, n, tol):
prob = example_problems[i]
a, b = prob["a"], prob["b"]
# evenly spaced points
x1 = a + (b - a) * np.linspace(0, 1, n)
f1 = prob["fun"](x1)
y1 = cumulative_simpson(f1, x=x1, initial=0) + prob["int"](a)
y3 = cumulative_simpson(f1, dx=np.diff(x1)[0], initial=0) + prob["int"](a)
np.testing.assert_allclose(y1, y3)
np.testing.assert_allclose(y1, prob["int"](x1), atol=tol, rtol=tol)

def test_prob0(self):
"""Test integrating log(x)."""
self._base(0, 10, 1e-2 / 8**0)
self._base(0, 20, 1e-2 / 8**1)
self._base(0, 40, 1e-2 / 8**2)
self._base(0, 80, 1e-2 / 8**3)

def test_prob1(self):
"""Test integrating a high order polynomial."""
self._base(1, 10, 2e-2 / 8**0)
self._base(1, 20, 2e-2 / 8**1)
self._base(1, 40, 2e-2 / 8**2)
self._base(1, 80, 2e-2 / 8**3)

def test_prob2(self):
"""Test integrating a gaussian."""
self._base(2, 10, 3e-2 / 8**0)
self._base(2, 20, 3e-2 / 8**1)
self._base(2, 40, 3e-2 / 8**2)
self._base(2, 80, 3e-2 / 8**3)

0 comments on commit 09217f2

Please sign in to comment.