diff --git a/CHANGELOG.md b/CHANGELOG.md index e2ec657d45..7724fa0cc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Features +- `plot` and `plot2D` now take and return a matplotlib Axis to allow for easier customization ([#1472](https://github.com/pybamm-team/PyBaMM/pull/1472)) +- `ParameterValues.evaluate` can now return arrays to allow function parameters to be easily evaluated ([#1472](https://github.com/pybamm-team/PyBaMM/pull/1472)) - Added Batch Study class ([#1455](https://github.com/pybamm-team/PyBaMM/pull/1455)) - Added `ConcatenationVariable`, which is automatically created when variables are concatenated ([#1453](https://github.com/pybamm-team/PyBaMM/pull/1453)) - Added "fast with events" mode for the CasADi solver, which solves a model and finds events more efficiently than "safe" mode. As of PR #1450 this feature is still being tested and "safe" mode remains the default ([#1450](https://github.com/pybamm-team/PyBaMM/pull/1450)) diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index 6c6555b7f3..f592ba7112 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -776,14 +776,14 @@ def evaluate(self, symbol): Returns ------- - number of array + number or array The evaluated symbol """ processed_symbol = self.process_symbol(symbol) - if processed_symbol.evaluates_to_constant_number(): + if processed_symbol.is_constant(): return processed_symbol.evaluate() else: - raise ValueError("symbol must evaluate to a constant scalar") + raise ValueError("symbol must evaluate to a constant scalar or array") def _ipython_key_completions_(self): return list(self._dict_items.keys()) diff --git a/pybamm/plotting/plot.py b/pybamm/plotting/plot.py index 7734ac2b74..19aa9dc5e0 100644 --- a/pybamm/plotting/plot.py +++ b/pybamm/plotting/plot.py @@ -5,7 +5,7 @@ from .quick_plot import ax_min, ax_max -def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs): +def plot(x, y, ax=None, testing=False, **kwargs): """ Generate a simple 1D plot. Calls `matplotlib.pyplot.plot` with keyword arguments 'kwargs'. For a list of 'kwargs' see the @@ -17,10 +17,8 @@ def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs): The array to plot on the x axis y : :class:`pybamm.Array` The array to plot on the y axis - xlabel : str, optional - The label for the x axis - ylabel : str, optional - The label for the y axis + ax : matplotlib Axis, optional + The axis on which to put the plot. If None, a new figure and axis is created. testing : bool, optional Whether to actually make the plot (turned off for unit tests) kwargs @@ -34,13 +32,15 @@ def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs): if not isinstance(y, pybamm.Array): raise TypeError("y must be 'pybamm.Array'") - plt.plot(x.entries, y.entries, **kwargs) - plt.ylim([ax_min(y.entries), ax_max(y.entries)]) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.title(title) + if ax is not None: + testing = True + else: + _, ax = plt.subplots() + + ax.plot(x.entries, y.entries, **kwargs) + ax.set_ylim([ax_min(y.entries), ax_max(y.entries)]) if not testing: # pragma: no cover plt.show() - return + return ax diff --git a/pybamm/plotting/plot2D.py b/pybamm/plotting/plot2D.py index 340379e196..80bb5d0ee2 100644 --- a/pybamm/plotting/plot2D.py +++ b/pybamm/plotting/plot2D.py @@ -5,7 +5,7 @@ from .quick_plot import ax_min, ax_max -def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwargs): +def plot2D(x, y, z, ax=None, testing=False, **kwargs): """ Generate a simple 2D plot. Calls `matplotlib.pyplot.contourf` with keyword arguments 'kwargs'. For a list of 'kwargs' see the @@ -19,12 +19,8 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg The array to plot on the y axis. Can be of shape (M, N) or (M, 1) z : :class:`pybamm.Array` The array to plot on the z axis. Is of shape (M, N) - xlabel : str, optional - The label for the x axis - ylabel : str, optional - The label for the y axis - title : str, optional - The title for the plot + ax : matplotlib Axis, optional + The axis on which to put the plot. If None, a new figure and axis is created. testing : bool, optional Whether to actually make the plot (turned off for unit tests) @@ -38,6 +34,11 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg if not isinstance(z, pybamm.Array): raise TypeError("z must be 'pybamm.Array'") + if ax is not None: + testing = True + else: + _, ax = plt.subplots() + # Get correct entries of x and y depending on shape if x.shape == y.shape == z.shape: x_entries = x.entries @@ -46,7 +47,7 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg x_entries = x.entries[:, 0] y_entries = y.entries[:, 0] - plt.contourf( + plot = ax.contourf( x_entries, y_entries, z.entries, @@ -54,12 +55,9 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg vmax=ax_max(z.entries), **kwargs ) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.title(title) - plt.colorbar() + plt.colorbar(plot, ax=ax) if not testing: # pragma: no cover plt.show() - return + return ax diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 289a5e0795..1787c849c7 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -756,13 +756,14 @@ def test_evaluate(self): c = pybamm.Parameter("c") self.assertEqual(parameter_values.evaluate(a), 1) self.assertEqual(parameter_values.evaluate(a + (b * c)), 7) + d = pybamm.Parameter("a") + pybamm.Parameter("b") * pybamm.Array([4, 5]) + np.testing.assert_array_equal( + parameter_values.evaluate(d), np.array([9, 11])[:, np.newaxis] + ) y = pybamm.StateVector(slice(0, 1)) with self.assertRaises(ValueError): parameter_values.evaluate(y) - array = pybamm.Array(np.array([1, 2, 3])) - with self.assertRaises(ValueError): - parameter_values.evaluate(array) def test_export_csv(self): def some_function(self): diff --git a/tests/unit/test_plotting/test_plot.py b/tests/unit/test_plotting/test_plot.py index 445095dc02..e30740db61 100644 --- a/tests/unit/test_plotting/test_plot.py +++ b/tests/unit/test_plotting/test_plot.py @@ -1,13 +1,18 @@ import pybamm import unittest import numpy as np +import matplotlib.pyplot as plt class TestPlot(unittest.TestCase): def test_plot(self): x = pybamm.Array(np.array([0, 3, 10])) y = pybamm.Array(np.array([6, 16, 78])) - pybamm.plot(x, y, xlabel="x", ylabel="y", title="title", testing=True) + pybamm.plot(x, y, testing=True) + + _, ax = plt.subplots() + ax_out = pybamm.plot(x, y, ax=ax, testing=True) + self.assertEqual(ax_out, ax) def test_plot_fail(self): x = pybamm.Array(np.array([0])) @@ -22,10 +27,14 @@ def test_plot2D(self): X, Y = pybamm.meshgrid(x, y) # plot with array directly - pybamm.plot2D(x, y, Y, xlabel="x", ylabel="y", title="title", testing=True) + pybamm.plot2D(x, y, Y, testing=True) # plot with meshgrid - pybamm.plot2D(X, Y, Y, xlabel="x", ylabel="y", title="title", testing=True) + pybamm.plot2D(X, Y, Y, testing=True) + + _, ax = plt.subplots() + ax_out = pybamm.plot2D(X, Y, Y, ax=ax, testing=True) + self.assertEqual(ax_out, ax) def test_plot2D_fail(self): x = pybamm.Array(np.array([0]))