Skip to content

Commit

Permalink
Merge pull request #2380 from pybamm-team/add-three-dimensional-inter…
Browse files Browse the repository at this point in the history
…polation

Add three dimensional interpolation
  • Loading branch information
valentinsulzer authored Oct 25, 2022
2 parents 58879c8 + c9e2aca commit ff670f9
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 17 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features

- Added three-dimensional interpolation [#2380](https://github.com/pybamm-team/PyBaMM/pull/2380)

## Bug fixes

- For simulations with events that cause the simulation to stop early, the sensitivities could be evaluated incorrectly to zero ([#2337](https://github.com/pybamm-team/PyBaMM/pull/2337))
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.13 ('python39-pybamm')",
"language": "python",
"name": "python3"
},
Expand All @@ -450,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.13"
},
"toc": {
"base_numbering": 1,
Expand All @@ -464,6 +464,11 @@
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
},
"vscode": {
"interpreter": {
"hash": "7dc94e087d5e42ea54b14035c48a0a59093d5180e7f512a1db8f70eb4b99d01e"
}
}
},
"nbformat": 4,
Expand Down
72 changes: 72 additions & 0 deletions examples/scripts/minimal_interp3d_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np

import pybamm
import matplotlib.pyplot as plt


def f(x, y, z):
return 2 * x**3 + 3 * y**2 - z


x = np.linspace(1, 4, 100)
y = np.linspace(4, 7, 105)
z = np.linspace(7, 9, 110)
xg, yg, zg = np.meshgrid(x, y, z, indexing="ij", sparse=True)
data = f(xg, yg, zg)

x_in = (x, y, z)

model = pybamm.BaseModel()

a = pybamm.Variable("a")
b = pybamm.Variable("b")
c = pybamm.Variable("c")
d = pybamm.Variable("d")

interp = pybamm.Interpolant(x_in, data, (a, b, c), interpolator="linear")

model.rhs = {a: 3, b: 3, c: 2, d: interp} # add to model
model.initial_conditions = {
a: pybamm.Scalar(1),
b: pybamm.Scalar(4),
c: pybamm.Scalar(7),
d: pybamm.Scalar(0),
}

model.variables = {
"Something": interp,
"a": a,
"b": b,
"c": c,
"d": d,
}

# solver = pybamm.CasadiSolver()
sim = pybamm.Simulation(model)

t_eval = np.linspace(0, 1, 100)
sim.solve(t_eval)

a_eval = sim.solution["a"](t_eval)
b_eval = sim.solution["b"](t_eval)
c_eval = sim.solution["c"](t_eval)
d_eval = sim.solution["d"](t_eval)
something = sim.solution["Something"](t_eval)

difference = something - f(a_eval, b_eval, c_eval)

fig, ax = plt.subplots(2, 1, figsize=(10, 5), sharex=True)

ax[0].plot(t_eval, f(a_eval, b_eval, c_eval), label="Original")
ax[0].plot(t_eval, something, label="Interpolated")
ax[0].set_ylabel("Value")
ax[0].legend()

ax[1].plot(t_eval, np.abs(f(a_eval, b_eval, c_eval) - something), label="Original")
ax[1].set_ylabel("Difference")

ax[-1].set_xlabel("Time [s]")
for a in ax:
a.grid()

plt.show()
100 changes: 95 additions & 5 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@

class Interpolant(pybamm.Function):
"""
Interpolate data in 1D.
Interpolate data in 1D, 2D, or 3D. Interpolation in 3D requires the input data to be
on a regular grid (as per scipy.interpolate.RegularGridInterpolator).
Parameters
----------
x : iterable of :class:`numpy.ndarray`
1-D array(s) of real values defining the data point coordinates.
The data point coordinates. If 1-D, then this is an array(s) of real values. If,
2D or 3D interpolation, then this is to ba a tuple of 1D arrays (one for each
dimension) which together define the coordinates of the points.
y : :class:`numpy.ndarray`
The values of the function to interpolate at the data points.
The values of the function to interpolate at the data points. In 2D and 3D, this
should be a matrix of two and three dimensions respectively.
children : iterable of :class:`pybamm.Symbol`
Node(s) to use when evaluating the interpolant. Each child corresponds to an
entry of x
Expand All @@ -26,10 +30,13 @@ class Interpolant(pybamm.Function):
function" is given.
interpolator : str, optional
Which interpolator to use. Can be "linear", "cubic", or "pchip". Default is
"linear".
"linear". For 3D interpolation, only "linear" an "cubic" are currently
supported.
extrapolate : bool, optional
Whether to extrapolate for points that are outside of the parametrisation
range, or return NaN (following default behaviour from scipy). Default is True.
Generally, it is best to set this to be False for 3D interpolation due to
the higher potential for errors in extrapolation.
**Extends**: :class:`pybamm.Function`
"""
Expand Down Expand Up @@ -71,6 +78,26 @@ def __init__(
"len(x2) should equal y=shape[0], "
f"but x2.shape={x2.shape} and y.shape={y.shape}"
)
elif isinstance(x, (tuple, list)) and len(x) == 3:
x1, x2, x3 = x
if y.ndim != 3:
raise ValueError("y should be three-dimensional if len(x)=3")

if x1.shape[0] != y.shape[0]:
raise ValueError(
"len(x1) should equal y=shape[0], "
f"but x1.shape={x1.shape} and y.shape={y.shape}"
)
if x2 is not None and x2.shape[0] != y.shape[1]:
raise ValueError(
"len(x2) should equal y=shape[1], "
f"but x2.shape={x2.shape} and y.shape={y.shape}"
)
if x3 is not None and x3.shape[0] != y.shape[2]:
raise ValueError(
"len(x3) should equal y=shape[2], "
f"but x3.shape={x3.shape} and y.shape={y.shape}"
)
else:
if isinstance(x, (tuple, list)):
x1 = x[0]
Expand Down Expand Up @@ -129,6 +156,28 @@ def __init__(
interpolating_function = interpolate.interp2d(
x1, x2, y, kind=interpolator
)
elif len(x) == 3:
self.dimension = 3

if extrapolate:
fill_value = None
else:
fill_value = np.nan

possible_interpolators = ["linear", "cubic"]
if interpolator not in possible_interpolators:
raise ValueError(
"""interpolator should be 'linear' or 'cubic'
for 3D interpolation"""
)
else:
interpolating_function = interpolate.RegularGridInterpolator(
(x1, x2, x3),
y,
method=interpolator,
bounds_error=False,
fill_value=fill_value,
)
else:
raise ValueError("Invalid dimension of x: {0}".format(len(x)))

Expand Down Expand Up @@ -197,7 +246,48 @@ def _function_evaluate(self, evaluated_children):
if res.ndim > 1:
return np.diagonal(res)[:, np.newaxis]
else:
# raise ValueError("Invalid children dimension: {0}".format(res.ndim))
return res[:, np.newaxis]
elif self.dimension == 3:

# If the children are scalars, we need to add a dimension
shapes = []
for child in evaluated_children:
if isinstance(child, (float, int)):
shapes.append(())
else:
shapes.append(child.shape)
shapes = set(shapes)
shapes.discard(())

if len(shapes) > 1:
raise ValueError(
"All children must have the same shape for 3D interpolation"
)

if len(shapes) == 0:
shape = (1,)
else:
shape = shapes.pop()
new_evaluated_children = []
for child in evaluated_children:

if hasattr(child, "shape") and child.shape == shape:
new_evaluated_children.append(child.flatten())
else:
new_evaluated_children.append(np.reshape(child, shape).flatten())

# return nans if there are any within the children
nans = np.isnan(new_evaluated_children)
if np.any(nans):
nan_children = []
for child, interp_range in zip(
new_evaluated_children, self.function.grid
):
nan_children.append(np.ones_like(child) * interp_range.mean())
return self.function(np.transpose(nan_children)) * np.nan
else:
res = self.function(np.transpose(new_evaluated_children))
return res[:, np.newaxis]

else: # pragma: no cover
raise ValueError("Invalid dimension: {0}".format(self.dimension))
2 changes: 1 addition & 1 deletion pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
)(*converted_children)
elif len(converted_children) == 2:
elif len(converted_children) in [2, 3]:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
)
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,102 @@ def test_interpolation_2_x_2d_y(self):
interp.evaluate(y=np.array([0, 0])), 0, decimal=3
)

def test_interpolation_3_x(self):
def f(x, y, z):
return 2 * x**3 + 3 * y**2 - z

x = np.linspace(1, 4, 11)
y = np.linspace(4, 7, 22)
z = np.linspace(7, 9, 33)
xg, yg, zg = np.meshgrid(x, y, z, indexing="ij", sparse=True)
data = f(xg, yg, zg)

var1 = pybamm.StateVector(slice(0, 1))
var2 = pybamm.StateVector(slice(1, 2))
var3 = pybamm.StateVector(slice(2, 3))

x_in = (x, y, z)
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="linear"
)

value = interp.evaluate(y=np.array([1, 5, 8]))
np.testing.assert_equal(value, f(1, 5, 8))

value = interp.evaluate(y=np.array([[1, 1, 1], [5, 4, 4], [8, 7, 7]]))
np.testing.assert_array_equal(
value, np.array([[f(1, 5, 8)], [f(1, 4, 7)], [f(1, 4, 7)]])
)

# check also works for cubic
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="cubic"
)
value = interp.evaluate(y=np.array([1, 5, 8]))
np.testing.assert_equal(value, f(1, 5, 8))

# Test raising error if data is not 3D
data_4d = np.zeros((11, 22, 33, 5))
with self.assertRaisesRegex(ValueError, "y should be three-dimensional"):
interp = pybamm.Interpolant(
x_in, data_4d, (var1, var2, var3), interpolator="linear"
)

# Test raising error if wrong shapes
with self.assertRaisesRegex(ValueError, "x1.shape"):
interp = pybamm.Interpolant(
x_in, np.zeros((12, 22, 33)), (var1, var2, var3), interpolator="linear"
)

with self.assertRaisesRegex(ValueError, "x2.shape"):
interp = pybamm.Interpolant(
x_in, np.zeros((11, 23, 33)), (var1, var2, var3), interpolator="linear"
)

with self.assertRaisesRegex(ValueError, "x3.shape"):
interp = pybamm.Interpolant(
x_in, np.zeros((11, 22, 34)), (var1, var2, var3), interpolator="linear"
)

# Raise error if not linear
with self.assertRaisesRegex(
ValueError, "interpolator should be 'linear' or 'cubic'"
):
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="pchip"
)

# Check returns nan if extrapolate set to False
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="linear", extrapolate=False
)
value = interp.evaluate(y=np.array([0, 0, 0]))
np.testing.assert_equal(value, np.nan)

# Check testing for shape works (i.e. using nans)
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="cubic"
)
interp.test_shape()

# test with inconsistent children shapes
# (this can occur is one child is a scaler and the others
# are vaiables)
evaluated_children = [np.array([[1]]), 4, np.array([[7]])]
value = interp._function_evaluate(evaluated_children)

evaluated_children = [np.array([[1]]), np.ones(()) * 4, np.array([[7]])]
value = interp._function_evaluate(evaluated_children)

# Test evaluation fails with different child shapes
with self.assertRaisesRegex(ValueError, "All children must"):
evaluated_children = [np.array([[1, 1]]), np.ones(()) * 4, np.array([[7]])]
value = interp._function_evaluate(evaluated_children)

# Test runs when all children are scalsrs
evaluated_children = [1, 4, 7]
value = interp._function_evaluate(evaluated_children)

def test_name(self):
a = pybamm.Symbol("a")
x = np.linspace(0, 1, 200)
Expand Down
Loading

0 comments on commit ff670f9

Please sign in to comment.