Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add three dimensional interpolation #2380

Merged
merged 20 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features

- Added three-dimensional interpolation [#2380](https://github.com/pybamm-team/PyBaMM/pull/2380)

## Bug fixes

- For simulations with events that cause the simulation to stop early, the sensitivities could be evaluated incorrectly to zero ([#2337](https://github.com/pybamm-team/PyBaMM/pull/2337))
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.13 ('python39-pybamm')",
"language": "python",
"name": "python3"
},
Expand All @@ -450,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.13"
},
"toc": {
"base_numbering": 1,
Expand All @@ -464,6 +464,11 @@
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
},
"vscode": {
"interpreter": {
"hash": "7dc94e087d5e42ea54b14035c48a0a59093d5180e7f512a1db8f70eb4b99d01e"
}
}
},
"nbformat": 4,
Expand Down
72 changes: 72 additions & 0 deletions examples/scripts/minimal_interp3d_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np

import pybamm
import matplotlib.pyplot as plt


def f(x, y, z):
return 2 * x**3 + 3 * y**2 - z


x = np.linspace(1, 4, 100)
y = np.linspace(4, 7, 105)
z = np.linspace(7, 9, 110)
xg, yg, zg = np.meshgrid(x, y, z, indexing="ij", sparse=True)
data = f(xg, yg, zg)

x_in = (x, y, z)

model = pybamm.BaseModel()

a = pybamm.Variable("a")
b = pybamm.Variable("b")
c = pybamm.Variable("c")
d = pybamm.Variable("d")

interp = pybamm.Interpolant(x_in, data, (a, b, c), interpolator="linear")

model.rhs = {a: 3, b: 3, c: 2, d: interp} # add to model
model.initial_conditions = {
a: pybamm.Scalar(1),
b: pybamm.Scalar(4),
c: pybamm.Scalar(7),
d: pybamm.Scalar(0),
}

model.variables = {
"Something": interp,
"a": a,
"b": b,
"c": c,
"d": d,
}

# solver = pybamm.CasadiSolver()
sim = pybamm.Simulation(model)

t_eval = np.linspace(0, 1, 100)
sim.solve(t_eval)

a_eval = sim.solution["a"](t_eval)
b_eval = sim.solution["b"](t_eval)
c_eval = sim.solution["c"](t_eval)
d_eval = sim.solution["d"](t_eval)
something = sim.solution["Something"](t_eval)

difference = something - f(a_eval, b_eval, c_eval)

fig, ax = plt.subplots(2, 1, figsize=(10, 5), sharex=True)

ax[0].plot(t_eval, f(a_eval, b_eval, c_eval), label="Original")
ax[0].plot(t_eval, something, label="Interpolated")
ax[0].set_ylabel("Value")
ax[0].legend()

ax[1].plot(t_eval, np.abs(f(a_eval, b_eval, c_eval) - something), label="Original")
ax[1].set_ylabel("Difference")

ax[-1].set_xlabel("Time [s]")
for a in ax:
a.grid()

plt.show()
100 changes: 95 additions & 5 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@

class Interpolant(pybamm.Function):
"""
Interpolate data in 1D.
Interpolate data in 1D, 2D, or 3D. Interpolation in 3D requires the input data to be
on a regular grid (as per scipy.interpolate.RegularGridInterpolator).

Parameters
----------
x : iterable of :class:`numpy.ndarray`
1-D array(s) of real values defining the data point coordinates.
The data point coordinates. If 1-D, then this is an array(s) of real values. If,
2D or 3D interpolation, then this is to ba a tuple of 1D arrays (one for each
dimension) which together define the coordinates of the points.
y : :class:`numpy.ndarray`
The values of the function to interpolate at the data points.
The values of the function to interpolate at the data points. In 2D and 3D, this
should be a matrix of two and three dimensions respectively.
children : iterable of :class:`pybamm.Symbol`
Node(s) to use when evaluating the interpolant. Each child corresponds to an
entry of x
Expand All @@ -26,10 +30,13 @@ class Interpolant(pybamm.Function):
function" is given.
interpolator : str, optional
Which interpolator to use. Can be "linear", "cubic", or "pchip". Default is
"linear".
"linear". For 3D interpolation, only "linear" an "cubic" are currently
supported.
extrapolate : bool, optional
Whether to extrapolate for points that are outside of the parametrisation
range, or return NaN (following default behaviour from scipy). Default is True.
Generally, it is best to set this to be False for 3D interpolation due to
the higher potential for errors in extrapolation.

**Extends**: :class:`pybamm.Function`
"""
Expand Down Expand Up @@ -71,6 +78,26 @@ def __init__(
"len(x2) should equal y=shape[0], "
f"but x2.shape={x2.shape} and y.shape={y.shape}"
)
elif isinstance(x, (tuple, list)) and len(x) == 3:
x1, x2, x3 = x
if y.ndim != 3:
raise ValueError("y should be three-dimensional if len(x)=3")

if x1.shape[0] != y.shape[0]:
raise ValueError(
"len(x1) should equal y=shape[0], "
f"but x1.shape={x1.shape} and y.shape={y.shape}"
)
if x2 is not None and x2.shape[0] != y.shape[1]:
raise ValueError(
"len(x2) should equal y=shape[1], "
f"but x2.shape={x2.shape} and y.shape={y.shape}"
)
if x3 is not None and x3.shape[0] != y.shape[2]:
raise ValueError(
"len(x3) should equal y=shape[2], "
f"but x3.shape={x3.shape} and y.shape={y.shape}"
)
else:
if isinstance(x, (tuple, list)):
x1 = x[0]
Expand Down Expand Up @@ -129,6 +156,28 @@ def __init__(
interpolating_function = interpolate.interp2d(
x1, x2, y, kind=interpolator
)
elif len(x) == 3:
self.dimension = 3

if extrapolate:
fill_value = None
else:
fill_value = np.nan

possible_interpolators = ["linear", "cubic"]
if interpolator not in possible_interpolators:
raise ValueError(
"""interpolator should be 'linear' or 'cubic'
for 3D interpolation"""
)
else:
interpolating_function = interpolate.RegularGridInterpolator(
(x1, x2, x3),
y,
method=interpolator,
bounds_error=False,
fill_value=fill_value,
)
else:
raise ValueError("Invalid dimension of x: {0}".format(len(x)))

Expand Down Expand Up @@ -197,7 +246,48 @@ def _function_evaluate(self, evaluated_children):
if res.ndim > 1:
return np.diagonal(res)[:, np.newaxis]
else:
# raise ValueError("Invalid children dimension: {0}".format(res.ndim))
return res[:, np.newaxis]
elif self.dimension == 3:

# If the children are scalars, we need to add a dimension
shapes = []
for child in evaluated_children:
if isinstance(child, (float, int)):
shapes.append(())
else:
shapes.append(child.shape)
shapes = set(shapes)
shapes.discard(())

if len(shapes) > 1:
raise ValueError(
"All children must have the same shape for 3D interpolation"
)

if len(shapes) == 0:
shape = (1,)
else:
shape = shapes.pop()
new_evaluated_children = []
for child in evaluated_children:

if hasattr(child, "shape") and child.shape == shape:
new_evaluated_children.append(child.flatten())
else:
new_evaluated_children.append(np.reshape(child, shape).flatten())

# return nans if there are any within the children
nans = np.isnan(new_evaluated_children)
if np.any(nans):
nan_children = []
for child, interp_range in zip(
new_evaluated_children, self.function.grid
):
nan_children.append(np.ones_like(child) * interp_range.mean())
return self.function(np.transpose(nan_children)) * np.nan
else:
res = self.function(np.transpose(new_evaluated_children))
return res[:, np.newaxis]

else: # pragma: no cover
raise ValueError("Invalid dimension: {0}".format(self.dimension))
2 changes: 1 addition & 1 deletion pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
)(*converted_children)
elif len(converted_children) == 2:
elif len(converted_children) in [2, 3]:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
)
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,102 @@ def test_interpolation_2_x_2d_y(self):
interp.evaluate(y=np.array([0, 0])), 0, decimal=3
)

def test_interpolation_3_x(self):
def f(x, y, z):
return 2 * x**3 + 3 * y**2 - z

x = np.linspace(1, 4, 11)
y = np.linspace(4, 7, 22)
z = np.linspace(7, 9, 33)
xg, yg, zg = np.meshgrid(x, y, z, indexing="ij", sparse=True)
data = f(xg, yg, zg)

var1 = pybamm.StateVector(slice(0, 1))
var2 = pybamm.StateVector(slice(1, 2))
var3 = pybamm.StateVector(slice(2, 3))

x_in = (x, y, z)
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="linear"
)

value = interp.evaluate(y=np.array([1, 5, 8]))
np.testing.assert_equal(value, f(1, 5, 8))

value = interp.evaluate(y=np.array([[1, 1, 1], [5, 4, 4], [8, 7, 7]]))
np.testing.assert_array_equal(
value, np.array([[f(1, 5, 8)], [f(1, 4, 7)], [f(1, 4, 7)]])
)

# check also works for cubic
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="cubic"
)
value = interp.evaluate(y=np.array([1, 5, 8]))
np.testing.assert_equal(value, f(1, 5, 8))

# Test raising error if data is not 3D
data_4d = np.zeros((11, 22, 33, 5))
with self.assertRaisesRegex(ValueError, "y should be three-dimensional"):
interp = pybamm.Interpolant(
x_in, data_4d, (var1, var2, var3), interpolator="linear"
)

# Test raising error if wrong shapes
with self.assertRaisesRegex(ValueError, "x1.shape"):
interp = pybamm.Interpolant(
x_in, np.zeros((12, 22, 33)), (var1, var2, var3), interpolator="linear"
)

with self.assertRaisesRegex(ValueError, "x2.shape"):
interp = pybamm.Interpolant(
x_in, np.zeros((11, 23, 33)), (var1, var2, var3), interpolator="linear"
)

with self.assertRaisesRegex(ValueError, "x3.shape"):
interp = pybamm.Interpolant(
x_in, np.zeros((11, 22, 34)), (var1, var2, var3), interpolator="linear"
)

# Raise error if not linear
with self.assertRaisesRegex(
ValueError, "interpolator should be 'linear' or 'cubic'"
):
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="pchip"
)

# Check returns nan if extrapolate set to False
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="linear", extrapolate=False
)
value = interp.evaluate(y=np.array([0, 0, 0]))
np.testing.assert_equal(value, np.nan)

# Check testing for shape works (i.e. using nans)
interp = pybamm.Interpolant(
x_in, data, (var1, var2, var3), interpolator="cubic"
)
interp.test_shape()

# test with inconsistent children shapes
# (this can occur is one child is a scaler and the others
# are vaiables)
evaluated_children = [np.array([[1]]), 4, np.array([[7]])]
value = interp._function_evaluate(evaluated_children)

evaluated_children = [np.array([[1]]), np.ones(()) * 4, np.array([[7]])]
value = interp._function_evaluate(evaluated_children)

# Test evaluation fails with different child shapes
with self.assertRaisesRegex(ValueError, "All children must"):
evaluated_children = [np.array([[1, 1]]), np.ones(()) * 4, np.array([[7]])]
value = interp._function_evaluate(evaluated_children)

# Test runs when all children are scalsrs
evaluated_children = [1, 4, 7]
value = interp._function_evaluate(evaluated_children)

def test_name(self):
a = pybamm.Symbol("a")
x = np.linspace(0, 1, 200)
Expand Down
Loading