Skip to content

Commit

Permalink
#2382 working on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Oct 20, 2022
1 parent f4dd305 commit 5cdc7c5
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 56 deletions.
4 changes: 2 additions & 2 deletions pybamm/models/base_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def __init__(
boundary_conditions=pybamm.ReadOnlyDict(boundary_conditions),
# Variables is initially empty, but will be filled in when variables are
# called
variables=_OnTheFlyUpdatedDict(
variables=_OnTheFlyUpdateDict(
unprocessed_variables, self.variables_update_function
),
events=events,
Expand All @@ -428,7 +428,7 @@ def rhs(self, value):
raise AttributeError(f"Attributes of {self} are read-only")


class _OnTheFlyUpdatedDict(dict):
class _OnTheFlyUpdateDict(dict):
"""
A dictionary that updates itself when a key is called.
"""
Expand Down
35 changes: 13 additions & 22 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ def _build_model(self):

self._equations.build_model_equations(self)

def update(self, *submodels):
self._equations.update(*submodels)

def set_initial_conditions_from(self, solution, inplace=True, return_type="model"):
"""
Update initial conditions with the final states from a Solution object or from
Expand Down Expand Up @@ -355,26 +358,10 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model
# Also update the concatenated initial conditions if the model is already
# discretised
if self.is_discretised:
# Unpack slices for sorting
y_slices = {var: slce for var, slce in self.y_slices.items()}
slices = []
for symbol in self.initial_conditions.keys():
if isinstance(symbol, pybamm.Concatenation):
# must append the slice for the whole concatenation, so that
# equations get sorted correctly
slices.append(
slice(
y_slices[symbol.children[0]][0].start,
y_slices[symbol.children[-1]][0].stop,
)
)
else:
slices.append(y_slices[symbol][0])
equations = list(initial_conditions.values())
# sort equations according to slices
sorted_equations = [eq for _, eq in sorted(zip(slices, equations))]
concatenated_initial_conditions = pybamm.NumpyConcatenation(
*sorted_equations
concatenated_initial_conditions = (
self._equations._discretisation._concatenate_in_order(
initial_conditions
)
)
else:
concatenated_initial_conditions = None
Expand All @@ -385,8 +372,12 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model
else:
model = self.new_copy()

model.initial_conditions = initial_conditions
model.concatenated_initial_conditions = concatenated_initial_conditions
model._equations._initial_conditions = pybamm.ReadOnlyDict(
initial_conditions
)
model._equations._concatenated_initial_conditions = (
concatenated_initial_conditions
)
return model
elif return_type == "ics":
return initial_conditions, concatenated_initial_conditions
Expand Down
2 changes: 1 addition & 1 deletion pybamm/models/symbolic_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def check_and_combine_dict(self, dict1, dict2):

class _EquationDict(dict):
def __init__(self, name, equations):
name = name
self.name = name
equations = self.check_and_convert_equations(equations)
super().__init__(equations)

Expand Down
8 changes: 4 additions & 4 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def check_parameter_values(self, values):
"by 3600 to get the same results as before."
)

def process_model(self, unprocessed_model, inplace=None):
def process_model(self, unprocessed_model, inplace=True):
"""Assign parameter values to a model.
Currently inplace, could be changed to return a new model.
Expand Down Expand Up @@ -447,7 +447,7 @@ def process_model(self, unprocessed_model, inplace=None):
new_scale = self.process_symbol(scale)
new_length_scales[domain] = new_scale

parameterized_equations = pybamm._ParameterisedEquations(
parameterised_equations = pybamm._ParameterisedEquations(
self,
new_rhs,
new_algebraic,
Expand All @@ -463,10 +463,10 @@ def process_model(self, unprocessed_model, inplace=None):
# inplace vs not inplace
if inplace:
model = unprocessed_model
model._equations = parameterized_equations
model._equations = parameterised_equations
else:
# create a copy of the model
model = unprocessed_model.new_copy(equations=parameterized_equations)
model = unprocessed_model.new_copy(equations=parameterised_equations)

pybamm.logger.info("Finish setting parameters for {}".format(model.name))

Expand Down
12 changes: 6 additions & 6 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def _set_up_model_sensitivities_inplace(
# First, we reset the mass matrix and bounds back to their original form
# if they have been extended
if model.bounds[0].shape[0] > model.len_rhs_and_alg:
model.bounds = (
model._equations._bounds = (
model.bounds[0][: model.len_rhs_and_alg],
model.bounds[1][: model.len_rhs_and_alg],
)
Expand All @@ -389,10 +389,10 @@ def _set_up_model_sensitivities_inplace(
and model.mass_matrix.shape[0] > model.len_rhs_and_alg
):
if model.mass_matrix_inv is not None:
model.mass_matrix_inv = pybamm.Matrix(
model._equations._mass_matrix_inv = pybamm.Matrix(
model.mass_matrix_inv.entries[: model.len_rhs, : model.len_rhs]
)
model.mass_matrix = pybamm.Matrix(
model._equations._mass_matrix = pybamm.Matrix(
model.mass_matrix.entries[
: model.len_rhs_and_alg, : model.len_rhs_and_alg
]
Expand All @@ -406,7 +406,7 @@ def _set_up_model_sensitivities_inplace(
elif model.len_alg != 0:
n_inputs = model.len_alg_sens // model.len_alg
if model.bounds[0].shape[0] == model.len_rhs_and_alg:
model.bounds = (
model._equations._bounds = (
np.repeat(model.bounds[0], n_inputs + 1),
np.repeat(model.bounds[1], n_inputs + 1),
)
Expand All @@ -416,13 +416,13 @@ def _set_up_model_sensitivities_inplace(
):

if model.mass_matrix_inv is not None:
model.mass_matrix_inv = pybamm.Matrix(
model._equations._mass_matrix_inv = pybamm.Matrix(
block_diag(
[model.mass_matrix_inv.entries] * (n_inputs + 1),
format="csr",
)
)
model.mass_matrix = pybamm.Matrix(
model._equations._mass_matrix = pybamm.Matrix(
block_diag(
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
)
Expand Down
3 changes: 3 additions & 0 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __len__(self):
def __iter__(self):
return iter(self._items)

def copy(self):
return ReadOnlyDict(self._items.copy())


class Timer(object):
"""
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_adding_1D_external_variable(self):

self.assertEqual(disc.y_slices[a][0], slice(0, 10, None))

self.assertEqual(model.y_slices[a][0], slice(0, 10, None))
self.assertEqual(model._equations._y_slices[a][0], slice(0, 10, None))
self.assertEqual(model.bounds, disc.bounds)

b_test = np.ones((10, 1))
Expand All @@ -138,8 +138,8 @@ def test_adding_1D_external_variable(self):
)

# check that b is added to the boundary conditions
model.bcs[b]["left"]
model.bcs[b]["right"]
disc.bcs[b]["left"]
disc.bcs[b]["right"]

# check that grad and div(grad ) produce the correct shapes
self.assertEqual(model.variables["b"].shape_for_testing, (10, 1))
Expand Down Expand Up @@ -205,8 +205,8 @@ def test_concatenation_external_variables(self):
)

# check that b is added to the boundary conditions
model.bcs[b]["left"]
model.bcs[b]["right"]
disc.bcs[b]["left"]
disc.bcs[b]["right"]

# check that grad and div(grad ) produce the correct shapes
self.assertEqual(model.variables["b"].shape_for_testing, (15, 1))
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_models/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_check_no_repeated_keys(self):
var = pybamm.Variable("var")
model.algebraic = {var: var}
with self.assertRaisesRegex(pybamm.ModelError, "Multiple equations specified"):
model.check_no_repeated_keys()
model._equations.check_no_repeated_keys()

def test_check_well_posedness_variables(self):
# Well-posed ODE model
Expand Down Expand Up @@ -1009,9 +1009,9 @@ def get_coupled_variables(self, variables):
"submodel 1": Submodel1(None, "negative"),
"submodel 2": Submodel2(None, "negative"),
}
self.assertFalse(model._built)
self.assertFalse(model._equations._built)
model.build_model()
self.assertTrue(model._built)
self.assertTrue(model._equations._built)
u = model.variables["u"]
v = model.variables["v"]
self.assertEqual(model.rhs[u].value, 2)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_ode_solver_fail_with_dae(self):
model = pybamm.BaseModel()
a = pybamm.Scalar(1)
model.algebraic = {a: a}
model.concatenated_initial_conditions = pybamm.Scalar(0)
model.initial_conditions = {a: 0}
solver = pybamm.ScipySolver()
with self.assertRaisesRegex(pybamm.SolverError, "Cannot use ODE solver"):
solver.set_up(model)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def test_dae_solver_algebraic_model(self):
np.testing.assert_array_equal(solution.y, -1)

# change initial_conditions and re-solve (to test if ics_only works)
model.concatenated_initial_conditions = pybamm.Vector(np.array([[1]]))
model._equations._concatenated_initial_conditions = pybamm.Vector(np.array([[1]]))
solution = solver.solve(model, t_eval)
np.testing.assert_array_equal(solution.y, -1)

Expand Down
9 changes: 0 additions & 9 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,15 +915,6 @@ def nonsmooth_rate(t):
)
np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06)

def test_ode_solver_fail_with_dae(self):
model = pybamm.BaseModel()
a = pybamm.Scalar(1)
model.algebraic = {a: a}
model.concatenated_initial_conditions = a
solver = pybamm.ScikitsOdeSolver()
with self.assertRaisesRegex(pybamm.SolverError, "Cannot use ODE solver"):
solver.set_up(model)

def test_dae_solver_algebraic_model(self):
model = pybamm.BaseModel()
var = pybamm.Variable("var")
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def test_model_solver_multiple_inputs_discontinuity_error(self):
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]

model.events = [
model._equations._events = [
pybamm.Event(
"discontinuity",
pybamm.Scalar(t_eval[-1] / 2),
Expand Down Expand Up @@ -495,7 +495,7 @@ def test_model_solver_manually_update_initial_conditions(self):
)

# Change initial conditions and solve again
model.concatenated_initial_conditions = pybamm.NumpyConcatenation(
model._equations._concatenated_initial_conditions = pybamm.NumpyConcatenation(
pybamm.Vector([[2]])
)
solution = solver.solve(model, t_eval)
Expand Down

0 comments on commit 5cdc7c5

Please sign in to comment.