Skip to content

Commit

Permalink
#3959 revert to iterative solver for interpolator and relax test tole…
Browse files Browse the repository at this point in the history
…rances
  • Loading branch information
brosaplanella committed May 20, 2024
1 parent 3abc482 commit fc3aa8e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
10 changes: 0 additions & 10 deletions pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,12 @@ def __init__(
fill_value = None
else:
fill_value = np.nan
if interpolator == "cubic":
solver = spsolve
else:
solver = None
interpolating_function = interpolate.RegularGridInterpolator(
(x1, x2),
y,
method=interpolator,
bounds_error=False,
fill_value=fill_value,
solver=solver,
)

elif len(x) == 3:
Expand All @@ -179,17 +174,12 @@ def __init__(
for 3D interpolation"""
)
else:
if interpolator == "cubic":
solver = spsolve
else:
solver = None
interpolating_function = interpolate.RegularGridInterpolator(
(x1, x2, x3),
y,
method=interpolator,
bounds_error=False,
fill_value=fill_value,
solver=solver,
)
else:
raise ValueError(f"Invalid dimension of x: {len(x)}")
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_expression_tree/test_interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def f(x, y):
# check also works for cubic
interp = pybamm.Interpolant(x_in, data, (var1, var2), interpolator="cubic")
value = interp.evaluate(y=np.array([1, 5]))
np.testing.assert_equal(value, f(1, 5))
# relaxed tolerance as from Scipy 1.13 it uses iterative solver
np.testing.assert_almost_equal(value, f(1, 5), decimal=3)

# Test raising error if data is not 2D
data_3d = np.zeros((11, 22, 33))
Expand Down Expand Up @@ -231,7 +232,8 @@ def f(x, y, z):
x_in, data, (var1, var2, var3), interpolator="cubic"
)
value = interp.evaluate(y=np.array([1, 5, 8]))
np.testing.assert_equal(value, f(1, 5, 8))
# relaxed tolerance as from Scipy 1.13 it uses iterative solver
np.testing.assert_almost_equal(value, f(1, 5, 8), decimal=3)

# Test raising error if data is not 3D
data_4d = np.zeros((11, 22, 33, 5))
Expand Down

0 comments on commit fc3aa8e

Please sign in to comment.