Skip to content

Commit

Permalink
Merge pull request #2180 from pybamm-team/fix-2d-interpolant
Browse files Browse the repository at this point in the history
fix shape for 2d interpolant
  • Loading branch information
valentinsulzer authored Jul 20, 2022
2 parents 33a4037 + 0f67a53 commit 41bf72b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

## Bug fixes

- Fixed 2D intepolant ([#2180](https://github.com/pybamm-team/PyBaMM/pull/2180))
- Fixes a bug where the SPMe always builds even when `build=False` ([#2169](https://github.com/pybamm-team/PyBaMM/pull/2169))
- Some events have been removed in the case where they are constant, i.e. can never be reached ([#2158](https://github.com/pybamm-team/PyBaMM/pull/2158))
- Raise explicit `NotImplementedError` if trying to call `bool()` on a pybamm Symbol (e.g. in an if statement condition) ([#2141](https://github.com/pybamm-team/PyBaMM/pull/2141))
Expand Down
25 changes: 15 additions & 10 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def __init__(
x1, x2 = x
if y.ndim != 2:
raise ValueError("y should be two-dimensional if len(x)=2")
if x1.shape[0] != y.shape[1]:
raise ValueError(
"len(x1) should equal y=shape[1], "
f"but x1.shape={x1.shape} and y.shape={y.shape}"
)
if x2 is not None and x2.shape[0] != y.shape[0]:
raise ValueError(
"len(x2) should equal y=shape[0], "
f"but x2.shape={x2.shape} and y.shape={y.shape}"
)
else:
interpolator = interpolator or "cubic spline"
if isinstance(x, (tuple, list)):
Expand All @@ -63,17 +73,12 @@ def __init__(
x1 = x
x = [x]
x2 = None
if x1.shape[0] != y.shape[0]:
raise ValueError(
"len(x1) should equal y=shape[0], "
f"but x1.shape={x1.shape} and y.shape={y.shape}"
)

if x1.shape[0] != y.shape[0]:
raise ValueError(
"len(x1) should equal y=shape[0], "
"but x1.shape={} and y.shape={}".format(x1.shape, y.shape)
)
if x2 is not None and x2.shape[0] != y.shape[1]:
raise ValueError(
"len(x2) should equal y=shape[1], "
"but x2.shape={} and y.shape={}".format(x2.shape, y.shape)
)
if isinstance(children, pybamm.Symbol):
children = [children]
# Either a single x is provided and there is one child
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def test_errors(self):
pybamm.Interpolant(np.ones(10), np.ones(11), pybamm.Symbol("a"))
with self.assertRaisesRegex(ValueError, "x2"):
pybamm.Interpolant(
(np.ones(10), np.ones(11)), np.ones((10, 12)), pybamm.Symbol("a")
(np.ones(12), np.ones(11)), np.ones((10, 12)), pybamm.Symbol("a")
)
with self.assertRaisesRegex(ValueError, "x1"):
pybamm.Interpolant(
(np.ones(11), np.ones(10)), np.ones((10, 12)), pybamm.Symbol("a")
)
with self.assertRaisesRegex(ValueError, "y should"):
pybamm.Interpolant(
Expand All @@ -29,7 +33,7 @@ def test_errors(self):
)
with self.assertRaisesRegex(ValueError, "should equal"):
pybamm.Interpolant(
(np.ones(10), np.ones(12)), np.ones((10, 12)), pybamm.Symbol("a")
(np.ones(12), np.ones(10)), np.ones((10, 12)), pybamm.Symbol("a")
)
with self.assertRaisesRegex(ValueError, "interpolator should be 'linear'"):
pybamm.Interpolant(
Expand Down Expand Up @@ -77,7 +81,7 @@ def test_interpolation_1_x_2d_y(self):
)

def test_interpolation_2_x_2d_y(self):
x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.05))
x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.01))
xx, yy = np.meshgrid(x[0], x[1])
z = np.sin(xx ** 2 + yy ** 2)
var1 = pybamm.StateVector(slice(0, 1))
Expand Down

0 comments on commit 41bf72b

Please sign in to comment.