Skip to content

Commit

Permalink
Merge pull request #1472 from pybamm-team/issue-1465-evaluate-functio…
Browse files Browse the repository at this point in the history
…n-parameters

#1465 allow evaluate parameter to return arrays
  • Loading branch information
rtimms authored Apr 29, 2021
2 parents b4735e1 + cb04f46 commit a31e209
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 33 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Features

- `plot` and `plot2D` now take and return a matplotlib Axis to allow for easier customization ([#1472](https://github.com/pybamm-team/PyBaMM/pull/1472))
- `ParameterValues.evaluate` can now return arrays to allow function parameters to be easily evaluated ([#1472](https://github.com/pybamm-team/PyBaMM/pull/1472))
- Added Batch Study class ([#1455](https://github.com/pybamm-team/PyBaMM/pull/1455))
- Added `ConcatenationVariable`, which is automatically created when variables are concatenated ([#1453](https://github.com/pybamm-team/PyBaMM/pull/1453))
- Added "fast with events" mode for the CasADi solver, which solves a model and finds events more efficiently than "safe" mode. As of PR #1450 this feature is still being tested and "safe" mode remains the default ([#1450](https://github.com/pybamm-team/PyBaMM/pull/1450))
Expand Down
6 changes: 3 additions & 3 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,14 +776,14 @@ def evaluate(self, symbol):
Returns
-------
number of array
number or array
The evaluated symbol
"""
processed_symbol = self.process_symbol(symbol)
if processed_symbol.evaluates_to_constant_number():
if processed_symbol.is_constant():
return processed_symbol.evaluate()
else:
raise ValueError("symbol must evaluate to a constant scalar")
raise ValueError("symbol must evaluate to a constant scalar or array")

def _ipython_key_completions_(self):
return list(self._dict_items.keys())
Expand Down
22 changes: 11 additions & 11 deletions pybamm/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .quick_plot import ax_min, ax_max


def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
def plot(x, y, ax=None, testing=False, **kwargs):
"""
Generate a simple 1D plot. Calls `matplotlib.pyplot.plot` with keyword
arguments 'kwargs'. For a list of 'kwargs' see the
Expand All @@ -17,10 +17,8 @@ def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
The array to plot on the x axis
y : :class:`pybamm.Array`
The array to plot on the y axis
xlabel : str, optional
The label for the x axis
ylabel : str, optional
The label for the y axis
ax : matplotlib Axis, optional
The axis on which to put the plot. If None, a new figure and axis is created.
testing : bool, optional
Whether to actually make the plot (turned off for unit tests)
kwargs
Expand All @@ -34,13 +32,15 @@ def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
if not isinstance(y, pybamm.Array):
raise TypeError("y must be 'pybamm.Array'")

plt.plot(x.entries, y.entries, **kwargs)
plt.ylim([ax_min(y.entries), ax_max(y.entries)])
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
if ax is not None:
testing = True
else:
_, ax = plt.subplots()

ax.plot(x.entries, y.entries, **kwargs)
ax.set_ylim([ax_min(y.entries), ax_max(y.entries)])

if not testing: # pragma: no cover
plt.show()

return
return ax
24 changes: 11 additions & 13 deletions pybamm/plotting/plot2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .quick_plot import ax_min, ax_max


def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
def plot2D(x, y, z, ax=None, testing=False, **kwargs):
"""
Generate a simple 2D plot. Calls `matplotlib.pyplot.contourf` with keyword
arguments 'kwargs'. For a list of 'kwargs' see the
Expand All @@ -19,12 +19,8 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg
The array to plot on the y axis. Can be of shape (M, N) or (M, 1)
z : :class:`pybamm.Array`
The array to plot on the z axis. Is of shape (M, N)
xlabel : str, optional
The label for the x axis
ylabel : str, optional
The label for the y axis
title : str, optional
The title for the plot
ax : matplotlib Axis, optional
The axis on which to put the plot. If None, a new figure and axis is created.
testing : bool, optional
Whether to actually make the plot (turned off for unit tests)
Expand All @@ -38,6 +34,11 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg
if not isinstance(z, pybamm.Array):
raise TypeError("z must be 'pybamm.Array'")

if ax is not None:
testing = True
else:
_, ax = plt.subplots()

# Get correct entries of x and y depending on shape
if x.shape == y.shape == z.shape:
x_entries = x.entries
Expand All @@ -46,20 +47,17 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg
x_entries = x.entries[:, 0]
y_entries = y.entries[:, 0]

plt.contourf(
plot = ax.contourf(
x_entries,
y_entries,
z.entries,
vmin=ax_min(z.entries),
vmax=ax_max(z.entries),
**kwargs
)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.colorbar()
plt.colorbar(plot, ax=ax)

if not testing: # pragma: no cover
plt.show()

return
return ax
7 changes: 4 additions & 3 deletions tests/unit/test_parameters/test_parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,13 +756,14 @@ def test_evaluate(self):
c = pybamm.Parameter("c")
self.assertEqual(parameter_values.evaluate(a), 1)
self.assertEqual(parameter_values.evaluate(a + (b * c)), 7)
d = pybamm.Parameter("a") + pybamm.Parameter("b") * pybamm.Array([4, 5])
np.testing.assert_array_equal(
parameter_values.evaluate(d), np.array([9, 11])[:, np.newaxis]
)

y = pybamm.StateVector(slice(0, 1))
with self.assertRaises(ValueError):
parameter_values.evaluate(y)
array = pybamm.Array(np.array([1, 2, 3]))
with self.assertRaises(ValueError):
parameter_values.evaluate(array)

def test_export_csv(self):
def some_function(self):
Expand Down
15 changes: 12 additions & 3 deletions tests/unit/test_plotting/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import pybamm
import unittest
import numpy as np
import matplotlib.pyplot as plt


class TestPlot(unittest.TestCase):
def test_plot(self):
x = pybamm.Array(np.array([0, 3, 10]))
y = pybamm.Array(np.array([6, 16, 78]))
pybamm.plot(x, y, xlabel="x", ylabel="y", title="title", testing=True)
pybamm.plot(x, y, testing=True)

_, ax = plt.subplots()
ax_out = pybamm.plot(x, y, ax=ax, testing=True)
self.assertEqual(ax_out, ax)

def test_plot_fail(self):
x = pybamm.Array(np.array([0]))
Expand All @@ -22,10 +27,14 @@ def test_plot2D(self):
X, Y = pybamm.meshgrid(x, y)

# plot with array directly
pybamm.plot2D(x, y, Y, xlabel="x", ylabel="y", title="title", testing=True)
pybamm.plot2D(x, y, Y, testing=True)

# plot with meshgrid
pybamm.plot2D(X, Y, Y, xlabel="x", ylabel="y", title="title", testing=True)
pybamm.plot2D(X, Y, Y, testing=True)

_, ax = plt.subplots()
ax_out = pybamm.plot2D(X, Y, Y, ax=ax, testing=True)
self.assertEqual(ax_out, ax)

def test_plot2D_fail(self):
x = pybamm.Array(np.array([0]))
Expand Down

0 comments on commit a31e209

Please sign in to comment.