From 1799b0f598cd8b36a11f7e15e90dea4675cfd22f Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 19 Jul 2022 11:36:54 -0400 Subject: [PATCH 1/4] fix shape for 2d interpolant --- pybamm/expression_tree/interpolant.py | 25 +++++++++++-------- .../test_expression_tree/test_interpolant.py | 8 ++++-- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 74870749b2..aaa9022cea 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -55,6 +55,16 @@ def __init__( x1, x2 = x if y.ndim != 2: raise ValueError("y should be two-dimensional if len(x)=2") + if x1.shape[0] != y.shape[1]: + raise ValueError( + "len(x1) should equal y=shape[1], " + f"but x1.shape={x1.shape} and y.shape={y.shape}" + ) + if x2 is not None and x2.shape[0] != y.shape[0]: + raise ValueError( + "len(x2) should equal y=shape[2], " + f"but x2.shape={x2.shape} and y.shape={y.shape}" + ) else: interpolator = interpolator or "cubic spline" if isinstance(x, (tuple, list)): @@ -63,17 +73,12 @@ def __init__( x1 = x x = [x] x2 = None + if x1.shape[0] != y.shape[0]: + raise ValueError( + "len(x1) should equal y=shape[0], " + f"but x1.shape={x1.shape} and y.shape={y.shape}" + ) - if x1.shape[0] != y.shape[0]: - raise ValueError( - "len(x1) should equal y=shape[0], " - "but x1.shape={} and y.shape={}".format(x1.shape, y.shape) - ) - if x2 is not None and x2.shape[0] != y.shape[1]: - raise ValueError( - "len(x2) should equal y=shape[1], " - "but x2.shape={} and y.shape={}".format(x2.shape, y.shape) - ) if isinstance(children, pybamm.Symbol): children = [children] # Either a single x is provided and there is one child diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index fec8e1a6a3..94e6155880 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -13,7 +13,11 @@ def test_errors(self): pybamm.Interpolant(np.ones(10), np.ones(11), pybamm.Symbol("a")) with self.assertRaisesRegex(ValueError, "x2"): pybamm.Interpolant( - (np.ones(10), np.ones(11)), np.ones((10, 12)), pybamm.Symbol("a") + (np.ones(12), np.ones(11)), np.ones((10, 12)), pybamm.Symbol("a") + ) + with self.assertRaisesRegex(ValueError, "x1"): + pybamm.Interpolant( + (np.ones(11), np.ones(10)), np.ones((10, 12)), pybamm.Symbol("a") ) with self.assertRaisesRegex(ValueError, "y should"): pybamm.Interpolant( @@ -77,7 +81,7 @@ def test_interpolation_1_x_2d_y(self): ) def test_interpolation_2_x_2d_y(self): - x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.05)) + x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.01)) xx, yy = np.meshgrid(x[0], x[1]) z = np.sin(xx ** 2 + yy ** 2) var1 = pybamm.StateVector(slice(0, 1)) From 13e2ce9a2c7aa9adf8cd0ece4660c15a5e74a581 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 19 Jul 2022 14:19:40 -0400 Subject: [PATCH 2/4] coverage --- tests/unit/test_expression_tree/test_interpolant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 94e6155880..1310714dbd 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -33,7 +33,7 @@ def test_errors(self): ) with self.assertRaisesRegex(ValueError, "should equal"): pybamm.Interpolant( - (np.ones(10), np.ones(12)), np.ones((10, 12)), pybamm.Symbol("a") + (np.ones(12), np.ones(10)), np.ones((10, 12)), pybamm.Symbol("a") ) with self.assertRaisesRegex(ValueError, "interpolator should be 'linear'"): pybamm.Interpolant( From 698567160e35c1a61e108e3e9ae395031cb49551 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 19 Jul 2022 14:24:51 -0400 Subject: [PATCH 3/4] update error text --- pybamm/expression_tree/interpolant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index aaa9022cea..78f0cceab8 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -62,7 +62,7 @@ def __init__( ) if x2 is not None and x2.shape[0] != y.shape[0]: raise ValueError( - "len(x2) should equal y=shape[2], " + "len(x2) should equal y=shape[0], " f"but x2.shape={x2.shape} and y.shape={y.shape}" ) else: From 0f67a53009bf47f0a38e296da0c341b52b609a0e Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Wed, 20 Jul 2022 11:52:00 -0400 Subject: [PATCH 4/4] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8cf3965c1..564a90bbac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ ## Bug fixes +- Fixed 2D intepolant ([#2180](https://github.com/pybamm-team/PyBaMM/pull/2180)) - Fixes a bug where the SPMe always builds even when `build=False` ([#2169](https://github.com/pybamm-team/PyBaMM/pull/2169)) - Some events have been removed in the case where they are constant, i.e. can never be reached ([#2158](https://github.com/pybamm-team/PyBaMM/pull/2158)) - Raise explicit `NotImplementedError` if trying to call `bool()` on a pybamm Symbol (e.g. in an if statement condition) ([#2141](https://github.com/pybamm-team/PyBaMM/pull/2141))