From a4fedaeb85b43159ccec4a9e0c418704a4632fd1 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 11 Aug 2023 16:33:31 +0000 Subject: [PATCH 01/29] Draft a serialisation method Create to_json() functions in corresponsing classes Working basic de/serialisation --- pybamm/expression_tree/array.py | 29 +++ pybamm/expression_tree/binary_operators.py | 96 +++++++++ pybamm/expression_tree/broadcasts.py | 5 + pybamm/expression_tree/concatenations.py | 64 ++++++ pybamm/expression_tree/functions.py | 35 ++++ pybamm/expression_tree/input_parameter.py | 14 ++ pybamm/expression_tree/interpolant.py | 19 ++ pybamm/expression_tree/parameter.py | 10 + pybamm/expression_tree/scalar.py | 9 + pybamm/expression_tree/state_vector.py | 23 ++ pybamm/expression_tree/symbol.py | 14 ++ pybamm/expression_tree/unary_operators.py | 87 +++++++- pybamm/expression_tree/variable.py | 19 ++ pybamm/models/base_model.py | 44 ++++ pybamm/serialisation/serialisation.py | 232 +++++++++++++++++++++ 15 files changed, 699 insertions(+), 1 deletion(-) create mode 100644 pybamm/serialisation/serialisation.py diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a9141041b3..d0ba8d1296 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -128,6 +128,35 @@ def to_equation(self): entries_list = self.entries.tolist() return sympy.Array(entries_list) + def to_json(self): + """ + Method to serialise an Array object into JSON. + """ + + if isinstance(self.entries, np.ndarray): + matrix = self.entries.tolist() + elif isinstance(self.entries, csr_matrix): + matrix = { + "shape": self.entries.shape, + "data": self.entries.data.tolist(), + "row_indices": self.entries.indices.tolist(), + "column_pointers": self.entries.indptr.tolist(), + } + else: + raise TypeError( + f"Ah! Dense matrix! {self.entries}" + ) # PL: Double check this + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "entries": matrix, + # "entries_string": self.entries_string.decode(), + } + + return json_dict + def linspace(start, stop, num=50, **kwargs): """ diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 749384e9bc..05520a081a 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -68,6 +68,20 @@ def __init__(self, name, left, right): self.left = self.children[0] self.right = self.children[1] + @classmethod + def _from_json(cls, name, left, right, domains): + """Use to instantiate when deserialising; discretisation has + already occured so pre-processing of binaries is not necessary.""" + instance = cls.__new__(cls) + + super(BinaryOperator, instance).__init__( + name, children=[left, right], domains=domains + ) + instance.left = instance.children[0] + instance.right = instance.children[1] + + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" # Possibly add brackets for clarity @@ -155,6 +169,15 @@ def to_equation(self): eq2 = child2.to_equation() return self._sympy_operator(eq1, eq2) + def to_json(self): + """ + Method to serialise a BinaryOperator object into JSON. 
+ """ + + json_dict = {"name": self.name, "id": self.id, "domains": self.domains} + + return json_dict + class Power(BinaryOperator): """ @@ -165,6 +188,12 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("**", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -206,6 +235,12 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("+", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) @@ -229,6 +264,12 @@ def __init__(self, left, right): super().__init__("-", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("-", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -254,6 +295,12 @@ def __init__(self, left, right): super().__init__("*", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("*", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -290,6 +337,13 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + # instance = super(MatrixMultiplication, cls)._from_json("@", left, right) + instance = super()._from_json("@", left, right, domains) + return instance + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # We shouldn't need this @@ -337,6 +391,12 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("/", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule @@ -381,6 +441,12 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("inner product", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -450,6 +516,12 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("==", left, right, domains) + return instance + def 
diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Equality should always be multiplied by something else so hopefully don't @@ -496,6 +568,12 @@ def __init__(self, name, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) + @classmethod + def _from_json(cls, name, left, right): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json(name, left, right) + return instance + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Heaviside should always be multiplied by something else so hopefully don't @@ -561,6 +639,12 @@ class Modulo(BinaryOperator): def __init__(self, left, right): super().__init__("%", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("%", left, right, domains) + return instance + def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -599,6 +683,12 @@ class Minimum(BinaryOperator): def __init__(self, left, right): super().__init__("minimum", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("minimum", left, right, domains) + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "minimum({!s}, {!s})".format(self.left, self.right) @@ -635,6 +725,12 @@ class Maximum(BinaryOperator): def __init__(self, left, right): super().__init__("maximum", left, right) + @classmethod + def _from_json(cls, left, right, domains): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = super()._from_json("maximum", left, right, domains) + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "maximum({!s}, {!s})".format(self.left, self.right) diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 32cf2c002b..45c37a55f0 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -50,6 +50,11 @@ def _diff(self, variable): # Differentiate the child and broadcast the result in the same way return self._unary_new_copy(self.child.diff(variable)) + def to_json(self): + raise NotImplementedError( + "pybamm.Broadcast: Serialisation is only implemented for post-discretisation." # PL: Come up with a better message! 
+ ) + class PrimaryBroadcast(Broadcast): """ diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 2185a0fad6..5e678af95f 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -43,6 +43,16 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None): super().__init__(name, children, domains=domains) + @classmethod + def _from_json(cls, *children, name, domains, concat_fun=None): + instance = cls.__new__(cls) + + super(Concatenation, instance).__init__(name, children, domains=domains) + + instance.concatenation_function = concat_fun + + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" out = self.name + "(" @@ -182,6 +192,18 @@ def __init__(self, *children): concat_fun=np.concatenate ) + @classmethod + def _from_json(cls, children, domains): + """See :meth:`pybamm.Concatenation._from_json()`.""" + instance = super()._from_json( + *children, + name="numpy_concatenation", + domains=domains, + concat_fun=np.concatenate + ) + + return instance + def _concatenation_jac(self, children_jacs): """See :meth:`pybamm.Concatenation.concatenation_jac()`.""" children = self.children @@ -250,6 +272,22 @@ def __init__(self, children, full_mesh, copy_this=None): self._children_slices = copy.copy(copy_this._children_slices) self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts + @classmethod + def _from_json( + cls, children, size, slices, children_slices, secondary_dimensions_npts, domains + ): + """See :meth:`pybamm.Concatenation._from_json()`.""" + instance = super()._from_json( + *children, name="domain_concatenation", domains=domains + ) + + instance._size = size + instance._slices = slices + instance._children_slices = children_slices + instance.secondary_dimensions_npts = secondary_dimensions_npts + + return instance + def _get_auxiliary_domain_repeats(self, auxiliary_domains): """Helper method to read the 'auxiliary_domain' meshes.""" mesh_pts = 1 @@ -315,6 +353,32 @@ def _concatenation_new_copy(self, children): ) return new_symbol + def to_json(self): + """ + Method to serialise a DomainConcatenation object into JSON. + """ + + def unpack_defaultDict(slices): + slices = dict(slices) + for domain, sls in slices.items(): + sls = [{"start": s.start, "stop": s.stop, "step": s.step} for s in sls] + slices[domain] = sls + return slices + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "slices": unpack_defaultDict(self._slices), + "size": self._size, + "children_slices": [ + unpack_defaultDict(child_slice) for child_slice in self._children_slices + ], + "secondary_dimensions_npts": self.secondary_dimensions_npts, + } + + return json_dict + class SparseStack(Concatenation): """ diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 80c2848ad9..c759cc0b51 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -211,6 +211,28 @@ def to_equation(self): eq_list.append(eq) return self._sympy_operator(*eq_list) + # PL: think I need something here. presumably I can serialise function methods using just their names, then rehydrate them at the point they're read back in? + def to_json(self): + """ + Method to serialise a Function object into JSON. 
+ """ + + try: + func_name = self.function.__name__ + except: + raise Exception + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "function": func_name, # PL: actually put name here + "derivative": self.derivative, + "differentiated_function": self.differentiated_function, # PL: same here (although is this defined? or is it just written out...) + } + + return json_dict + def simplified_function(func_class, child): """ @@ -254,6 +276,19 @@ def _sympy_operator(self, child): sympy_function = getattr(sympy, class_name) return sympy_function(child) + def to_json(self): + """ + Method to serialise a SpecificFunction object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "function": self.function.__name__, + } + + return json_dict + class Arcsinh(SpecificFunction): """Arcsinh function.""" diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 62c08bf0fd..1f772bc325 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -101,3 +101,17 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): self._expected_size ) ) + + def to_json(self): + """ + Method to serialise an InputParameter object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domain": self.domain, + "expected_size": self._expected_size, + } + + return json_dict diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index cd0df4d077..9555dcaa34 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -290,3 +290,22 @@ def _function_evaluate(self, evaluated_children): else: # pragma: no cover raise ValueError("Invalid dimension: {0}".format(self.dimension)) + + # PL: think I need something here. presumably I can serialise function methods using just their names, then rehydrate them at the point they're read back in? + def to_json(self): + """ + Method to serialise an Interpolant object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + # "domains": self.domains, + "x": self.x.tolist(), + "y": self.y.tolist(), + "interpolator": self.interpolator, + "extrapolate": self.extrapolate, + # "entries_string": self.entries_string, + } + + return json_dict diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 10addae464..d8aa146fd9 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -49,6 +49,11 @@ def to_equation(self): else: return sympy.Symbol(self.name) + def to_json(self): + raise NotImplementedError( + "pybamm.Parameter: Serialisation is only implemented for post-discretisation." # PL: Come up with a better message! + ) + class FunctionParameter(pybamm.Symbol): """ @@ -221,3 +226,8 @@ def to_equation(self): return sympy.Symbol(self.print_name) else: return sympy.Symbol(self.name) + + def to_json(self): + raise NotImplementedError( + "pybamm.FunctionParameter: Serialisation is only implemented for post-discretisation." # PL: Come up with a better message! + ) diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 3149bf7bee..ae2b63560d 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -74,3 +74,12 @@ def to_equation(self): return sympy.Symbol(self.print_name) else: return self.value + + def to_json(self): + """ + Method to serialise a Symbol object into JSON. 
+ """ + + json_dict = {"name": self.name, "id": self.id, "value": self.value} + + return json_dict diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 6ef8bee904..2c101e0a24 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -194,6 +194,29 @@ def _evaluate_for_shape(self): """ return np.nan * np.ones((self.size, 1)) + def to_json(self): + """ + Method to serialise a StateVector object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "y_slice": [ + { + "start": y.start, + "stop": y.stop, + "step": y.step, + } # are there ever more than 1? + for y in self.y_slices + ], + "evaluation_array": list(self.evaluation_array), + # "children": self.children, # might not need this, the anytree exporter handles children I think + } + + return json_dict + class StateVector(StateVectorBase): """ diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 5d28884ed5..037205fda0 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -985,3 +985,17 @@ def print_name(self, name): def to_equation(self): return sympy.Symbol(str(self.name)) + + def to_json(self): + """ + Method to serialise a Symbol object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + # "children": self.children, # the encoder deals with the children itself. + } + + return json_dict diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 7f9c45775c..efd0914464 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -316,6 +316,20 @@ def _evaluates_on_edges(self, dimension): """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False + def to_json(self): + """ + Method to serialise an Index object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "check_size": False, + } + + return json_dict + class SpatialOperator(UnaryOperator): """ @@ -581,6 +595,20 @@ def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" return sympy.Integral(child, sympy.Symbol("xn")) + def to_json(self): + """ + Method to serialise an Integral object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "integration_variable": self.integration_variable, # PL: This may be a (list of) symbols that need cycling through in a similar mannar to children + } + + return json_dict + class BaseIndefiniteIntegral(Integral): """ @@ -685,7 +713,8 @@ class DefiniteIntegralVector(SpatialOperator): Parameters ---------- variable : :class:`pybamm.Symbol` - The variable whose basis will be integrated over the entire domain + The variable whose basis will be integrated over the entire domain (will + become self.children[0]) vector_type : str, optional Whether to return a row or column vector (default is row) """ @@ -714,6 +743,20 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) + def to_json(self): + """ + Method to serialise a DefiniteIntegralVector object into JSON. 
+ """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "vector_type": self.vector_type, + } + + return json_dict + class BoundaryIntegral(SpatialOperator): """ @@ -771,6 +814,20 @@ def _evaluates_on_edges(self, dimension): """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False + def to_json(self): + """ + Method to serialise a BoundaryIntegral object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, # PL: Not sure if this exists, but might inherit from symbol + "region": self.region, + } + + return json_dict + class DeltaFunction(SpatialOperator): """ @@ -815,6 +872,20 @@ def evaluate_for_shape(self): return np.outer(child_eval, vec).reshape(-1, 1) + def to_json(self): + """ + Method to serialise a DeltaFunction object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "side": self.side, + } + + return json_dict + class BoundaryOperator(SpatialOperator): """ @@ -867,6 +938,20 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) + def to_json(self): + """ + Method to serialise a BoundaryOperator object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "side": self.side, + } + + return json_dict + class BoundaryValue(BoundaryOperator): """ diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index f9f7d94efc..1349901a9a 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -129,6 +129,25 @@ def to_equation(self): else: return self.name + def to_json( + self, + ): # PL: This may never be touched if once discretised, it's turned into a statevector/statevectordot type. + """ + Method to serialise a BoundaryOperator object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "bounds": self.bounds, # tuple + "print_name": self.print_name, # string + "scale": self.scale, # float/symbol + "reference": self.reference, # float/symbol + } + + return json_dict + class Variable(VariableBase): """ diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 41192dbe1f..7e1f9b060f 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -123,6 +123,50 @@ def __init__(self, name="Unnamed model"): self.is_discretised = False self.y_slices = None + @classmethod + def deserialise(cls, properties: dict): + """ + Create a model instance from a serialised object. 
+ """ + instance = cls.__new__(cls) + + instance.name = properties["name"] + instance._options = {} + instance._built = False + instance._built_fundamental = False + + # Initialise model with stored variables + instance.submodels = {} + instance._rhs = {} + instance._algebraic = {} + instance._initial_conditions = {} + instance._boundary_conditions = {} + instance._variables = pybamm.FuzzyDict({}) + instance._events = [] + instance._concatenated_rhs = properties["concatenated_rhs"] + instance._concatenated_algebraic = properties["concatenated_algebraic"] + instance._concatenated_initial_conditions = properties[ + "concatenated_initial_conditions" + ] + instance._mass_matrix = None + instance._mass_matrix_inv = None + instance._jacobian = None + instance._jacobian_algebraic = None + instance._parameters = None + instance._input_parameters = None + instance._parameter_info = None + instance._variables_casadi = {} + + # Default behaviour is to use the jacobian + instance.use_jacobian = True + instance.convert_to_format = "casadi" + + # Model has already been discretised + instance.is_discretised = True + instance.y_slices = None + + return instance + @property def name(self): return self._name diff --git a/pybamm/serialisation/serialisation.py b/pybamm/serialisation/serialisation.py new file mode 100644 index 0000000000..7075c7839d --- /dev/null +++ b/pybamm/serialisation/serialisation.py @@ -0,0 +1,232 @@ +import pybamm +from anytree.exporter import JsonExporter +from anytree.importer import JsonImporter +import json +import numpy as np +import pprint +import importlib +from scipy.sparse import csr_matrix, csr_array +from collections import defaultdict + + +class SymbolEncoder(json.JSONEncoder): + def default(self, node): + node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + if isinstance(node, pybamm.Symbol): + node_dict.update(node.to_json()) # this doesn't include children + node_dict["children"] = [] + for c in node.children: + node_dict["children"].append(self.default(c)) + + return node_dict + + json_obj = json.JSONEncoder.default(self, node) + node_dict["json"] = json_obj + return node_dict + + +## DECODE + + +class _Empty: + pass + + +def reconstruct_symbol(dct): + def recreate_slice(d): + return slice(d["start"], d["stop"], d["step"]) + + # decode non-symbol objects here + # now for pybamm + foo = _Empty() + parts = dct["py/object"].split(".") + try: + module = importlib.import_module(".".join(parts[:-1])) + except Exception as ex: + print(ex) + + class_ = getattr(module, parts[-1]) + foo.__class__ = class_ + + # PL: Need to finish off the various options here. 
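+    # Dispatch on the reconstructed class: each branch below re-initialises
+    # `foo` from the JSON dict using the constructor arguments that class
+    # expects. Note that reconstruct_epression_tree rebuilds the tree
+    # bottom-up, so by the time this runs dct["children"] already contains
+    # Symbol objects rather than dictionaries.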
+ if isinstance(foo, pybamm.Scalar): + foo.__init__(dct["value"], name=dct["name"]) + + elif isinstance(foo, pybamm.BinaryOperator): + foo = foo._from_json(dct["children"][0], dct["children"][1], dct["domains"]) + + elif isinstance(foo, pybamm.Array): + if isinstance(dct["entries"], dict): + matrix = csr_array( + ( + dct["entries"]["data"], + dct["entries"]["row_indices"], + dct["entries"]["column_pointers"], + ), + shape=dct["entries"]["shape"], + ) + else: + matrix = dct["entries"] + foo.__init__( + matrix, + name=dct["name"], + domains=dct["domains"], + # entries_string=dct["entries_string"], + ) + + elif isinstance(foo, pybamm.StateVectorBase): + y_slices = [recreate_slice(d) for d in dct["y_slice"]] + foo.__init__( + *y_slices, + name=dct["name"], + domains=dct["domains"], + evaluation_array=dct["evaluation_array"], + ) + + elif isinstance(foo, pybamm.IndependentVariable): + if isinstance(foo, pybamm.Time): + foo.__init__() + else: + foo.__init__(dct["name"], domains=dct["domains"]) + + elif isinstance(foo, pybamm.InputParameter): + foo.__init__( + dct["name"], domain=dct["domain"], expected_size=dct["expected_size"] + ) + + elif isinstance(foo, pybamm.SpecificFunction): + foo.__init__(dct["children"][0]) + + elif isinstance(foo, pybamm.Function): + func = getattr( + np, dct["function"] + ) # don't think this will work for self-defined functions + foo.__init__( + func, + name=dct["name"], + derivative=dct["derivative"], + differentiated_function=dct["differentiated_function"], + ) + + elif isinstance(foo, pybamm.DomainConcatenation): + + def repack_defaultDict(slices): + slices = defaultdict(list, slices) + for domain, sls in slices.items(): + sls = [recreate_slice(s) for s in sls] + slices[domain] = sls + return slices + + main_slice = repack_defaultDict(dct["slices"]) + child_slice = [repack_defaultDict(s) for s in dct["children_slices"]] + + foo = foo._from_json( + dct["children"], + dct["size"], + main_slice, + child_slice, + dct["secondary_dimensions_npts"], + dct["domains"], + ) + + elif isinstance(foo, pybamm.NumpyConcatenation): + foo = foo._from_json( + dct["children"], + dct["domains"], + ) + # interpolant + # check various Unary operators, they differ + # VariableBase + # ... + elif isinstance(foo, pybamm.Symbol): + foo.__init__(dct["name"], children=dct["children"], domains=dct["domains"]) + + return foo + + +def reconstruct_epression_tree(node): + if "children" in node: + for i, c in enumerate(node["children"]): + child_obj = reconstruct_epression_tree(c) + node["children"][i] = child_obj + + obj = reconstruct_symbol(node) + + return obj + + +## Run tests +model = pybamm.lithium_ion.DFN() +geometry = model.default_geometry +param = model.default_parameter_values +param.process_model(model) +param.process_geometry(geometry) +mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) +disc = pybamm.Discretisation(mesh, model.default_spatial_methods) +disc.process_model(model) + +# # tested all individual trees in rhs +# # tree1 = list(model.rhs.items())[2][1] +# tree1 = ( +# model.y_slices +# ) # Worked: concatenated_rhs, concat_initial_conditions, concatenated_algebraic. +# # Do we need the 'unconcatenated' rhs etc? if not, this gets much easier. 
+# tree1.visualise("tree1.png") + +# json_tree1 = SymbolEncoder().default(tree1) +# with open("test_tree1.json", "w") as f: +# json.dump(json_tree1, f) + +# # pprint.pprint(json_tree1, sort_dicts=False) + +# with open("test_tree1.json", "r") as f: +# data = json.load(f) + +# tree1_recon = reconstruct_epression_tree(data) + +# print(tree1 == tree1_recon) + + +# tree1_recon.visualise("recon1.png") + +solver_initial = model.default_solver +solution_initial = solver_initial.solve(model, [0, 3600]) + +# pybamm.plot(solution_initial) +# solution_initial.plot() + +model_json = { + "py/object": str(type(model))[8:-2], + "py/id": id(model), + "name": model.name, + "concatenated_rhs": SymbolEncoder().default(model._concatenated_rhs), + "concatenated_algebraic": SymbolEncoder().default(model._concatenated_algebraic), + "concatenated_initial_conditions": SymbolEncoder().default( + model._concatenated_initial_conditions + ), +} + +# file_name = f"test_{model.name}_stored" +with open("test_full_model.json", "w") as f: + json.dump(model_json, f) + +with open("test_full_model.json", "r") as f: + model_data = json.load(f) + +recon_model_dict = { + "name": model_data["name"], + "concatenated_rhs": reconstruct_epression_tree(model_data["concatenated_rhs"]), + "concatenated_algebraic": reconstruct_epression_tree( + model_data["concatenated_algebraic"] + ), + "concatenated_initial_conditions": reconstruct_epression_tree( + model_data["concatenated_initial_conditions"] + ), +} + +new_model = pybamm.lithium_ion.DFN.deserialise(recon_model_dict) + +new_solver = new_model.default_solver +new_solution = new_solver.solve(model, [0, 3600]) + +# THIS WORKS!!! From 70b765d6855fa25c9ba6358c696a7fdaa278d5ce Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 18 Aug 2023 16:05:20 +0000 Subject: [PATCH 02/29] Move deserialisation functions to Symbol classes Creates _from_json() functionality --- pybamm/expression_tree/array.py | 26 +++- pybamm/expression_tree/binary_operators.py | 57 +++++---- pybamm/expression_tree/concatenations.py | 32 +++-- pybamm/expression_tree/functions.py | 119 +++++++++++++++--- .../expression_tree/independent_variable.py | 16 +++ pybamm/expression_tree/input_parameter.py | 12 ++ pybamm/expression_tree/interpolant.py | 18 +-- pybamm/expression_tree/scalar.py | 8 ++ pybamm/expression_tree/state_vector.py | 15 +++ pybamm/expression_tree/symbol.py | 20 +++ pybamm/expression_tree/unary_operators.py | 51 ++------ pybamm/expression_tree/variable.py | 18 +-- pybamm/models/base_model.py | 26 +--- pybamm/serialisation/serialisation.py | 72 +---------- 14 files changed, 266 insertions(+), 224 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index d0ba8d1296..c2fcbc4a11 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -3,7 +3,7 @@ # import numpy as np import sympy -from scipy.sparse import csr_matrix, issparse +from scipy.sparse import csr_matrix, issparse, csr_array import pybamm @@ -57,6 +57,30 @@ def __init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + if isinstance(snippet["entries"], dict): + matrix = csr_array( + ( + snippet["entries"]["data"], + snippet["entries"]["row_indices"], + snippet["entries"]["column_pointers"], + ), + shape=snippet["entries"]["shape"], + ) + else: + matrix = snippet["entries"] + + instance.__init__( + matrix, + name=snippet["name"], + domains=snippet["domains"], + ) + + return 
instance + @property def entries(self): return self._entries diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 05520a081a..30a81ee416 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -69,13 +69,16 @@ def __init__(self, name, left, right): self.right = self.children[1] @classmethod - def _from_json(cls, name, left, right, domains): + def _from_json(cls, name, snippet: dict): """Use to instantiate when deserialising; discretisation has already occured so pre-processing of binaries is not necessary.""" + instance = cls.__new__(cls) super(BinaryOperator, instance).__init__( - name, children=[left, right], domains=domains + name, + children=[snippet["children"][0], snippet["children"][1]], + domains=snippet["domains"], ) instance.left = instance.children[0] instance.right = instance.children[1] @@ -189,9 +192,9 @@ def __init__(self, left, right): super().__init__("**", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("**", left, right, domains) + instance = super()._from_json("**", snippet) return instance def _diff(self, variable): @@ -236,9 +239,9 @@ def __init__(self, left, right): super().__init__("+", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("+", left, right, domains) + instance = super()._from_json("+", snippet) return instance def _diff(self, variable): @@ -265,9 +268,9 @@ def __init__(self, left, right): super().__init__("-", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("-", left, right, domains) + instance = super()._from_json("-", snippet) return instance def _diff(self, variable): @@ -296,9 +299,9 @@ def __init__(self, left, right): super().__init__("*", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("*", left, right, domains) + instance = super()._from_json("*", snippet) return instance def _diff(self, variable): @@ -338,10 +341,10 @@ def __init__(self, left, right): super().__init__("@", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" # instance = super(MatrixMultiplication, cls)._from_json("@", left, right) - instance = super()._from_json("@", left, right, domains) + instance = super()._from_json("@", snippet) return instance def diff(self, variable): @@ -392,9 +395,9 @@ def __init__(self, left, right): super().__init__("/", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("/", left, right, domains) + instance = super()._from_json("/", snippet) return instance def _diff(self, variable): @@ -442,9 +445,9 @@ def __init__(self, left, right): super().__init__("inner product", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See 
:meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("inner product", left, right, domains) + instance = super()._from_json("inner product", snippet) return instance def _diff(self, variable): @@ -517,9 +520,9 @@ def __init__(self, left, right): super().__init__("==", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("==", left, right, domains) + instance = super()._from_json("==", snippet) return instance def diff(self, variable): @@ -569,9 +572,11 @@ def __init__(self, name, left, right): super().__init__(name, left, right) @classmethod - def _from_json(cls, name, left, right): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json(name, left, right) + instance = super()._from_json( + snippet["name"], snippet["children"][0], snippet["children"][1] + ) return instance def diff(self, variable): @@ -640,9 +645,9 @@ def __init__(self, left, right): super().__init__("%", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("%", left, right, domains) + instance = super()._from_json("%", snippet) return instance def _diff(self, variable): @@ -684,9 +689,9 @@ def __init__(self, left, right): super().__init__("minimum", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("minimum", left, right, domains) + instance = super()._from_json("minimum", snippet) return instance def __str__(self): @@ -726,9 +731,9 @@ def __init__(self, left, right): super().__init__("maximum", left, right) @classmethod - def _from_json(cls, left, right, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("maximum", left, right, domains) + instance = super()._from_json("maximum", snippet) return instance def __str__(self): diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 5e678af95f..af3db72846 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -45,6 +45,7 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None): @classmethod def _from_json(cls, *children, name, domains, concat_fun=None): + # PL: update this one - I guess we still want it to take 'snippet' rather than the list? to be the same as the others? 
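+        # A snippet-based signature (sketch only, mirroring the other
+        # _from_json methods) would look roughly like:
+        #     def _from_json(cls, snippet: dict):
+        #         instance = cls.__new__(cls)
+        #         super(Concatenation, instance).__init__(
+        #             snippet["name"], snippet["children"], domains=snippet["domains"]
+        #         )
+        # with concat_fun still supplied by the subclass, as below.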
instance = cls.__new__(cls) super(Concatenation, instance).__init__(name, children, domains=domains) @@ -193,12 +194,12 @@ def __init__(self, *children): ) @classmethod - def _from_json(cls, children, domains): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.Concatenation._from_json()`.""" instance = super()._from_json( - *children, + *snippet["children"], name="numpy_concatenation", - domains=domains, + domains=snippet["domains"], concat_fun=np.concatenate ) @@ -273,18 +274,27 @@ def __init__(self, children, full_mesh, copy_this=None): self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts @classmethod - def _from_json( - cls, children, size, slices, children_slices, secondary_dimensions_npts, domains - ): + def _from_json(cls, snippet: dict): """See :meth:`pybamm.Concatenation._from_json()`.""" instance = super()._from_json( - *children, name="domain_concatenation", domains=domains + *snippet["children"], + name="domain_concatenation", + domains=snippet["domains"] ) - instance._size = size - instance._slices = slices - instance._children_slices = children_slices - instance.secondary_dimensions_npts = secondary_dimensions_npts + def repack_defaultDict(slices): + slices = defaultdict(list, slices) + for domain, sls in slices.items(): + sls = [slice(s["start"], s["stop"], s["step"]) for s in sls] + slices[domain] = sls + return slices + + instance._size = snippet["size"] + instance._slices = repack_defaultDict(snippet["slices"]) + instance._children_slices = [ + repack_defaultDict(s) for s in snippet["children_slices"] + ] + instance.secondary_dimensions_npts = snippet["secondary_dimensions_npts"] return instance diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index c759cc0b51..17732c7ba4 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -7,6 +7,7 @@ import numpy as np import sympy from scipy import special +from typing import Callable import pybamm @@ -211,27 +212,8 @@ def to_equation(self): eq_list.append(eq) return self._sympy_operator(*eq_list) - # PL: think I need something here. presumably I can serialise function methods using just their names, then rehydrate them at the point they're read back in? def to_json(self): - """ - Method to serialise a Function object into JSON. - """ - - try: - func_name = self.function.__name__ - except: - raise Exception - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, - "function": func_name, # PL: actually put name here - "derivative": self.derivative, - "differentiated_function": self.differentiated_function, # PL: same here (although is this defined? or is it just written out...) - } - - return json_dict + raise NotImplementedError() def simplified_function(func_class, child): @@ -266,6 +248,25 @@ class SpecificFunction(Function): def __init__(self, function, child): super().__init__(function, child) + @classmethod + def _from_json(cls, function: Callable, snippet: dict): + """ + Reconstructs a SpecificFunction instance during deserialisation of a JSON file. 
+ + Parameters + ---------- + function : method + Function to be applied to child + snippet: dict + Contains the child to apply the function to + """ + + instance = cls.__new__(cls) + + super(SpecificFunction, instance).__init__(function, snippet["children"][0]) + + return instance + def _function_new_copy(self, children): """See :meth:`pybamm.Function._function_new_copy()`""" return pybamm.simplify_if_constant(self.__class__(*children)) @@ -296,6 +297,12 @@ class Arcsinh(SpecificFunction): def __init__(self, child): super().__init__(np.arcsinh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.arcsinh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Symbol._function_diff()`.""" return 1 / sqrt(children[0] ** 2 + 1) @@ -316,6 +323,12 @@ class Arctan(SpecificFunction): def __init__(self, child): super().__init__(np.arctan, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.arctan, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return 1 / (children[0] ** 2 + 1) @@ -336,6 +349,12 @@ class Cos(SpecificFunction): def __init__(self, child): super().__init__(np.cos, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.cos, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Symbol._function_diff()`.""" return -sin(children[0]) @@ -352,6 +371,12 @@ class Cosh(SpecificFunction): def __init__(self, child): super().__init__(np.cosh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.cosh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return sinh(children[0]) @@ -368,6 +393,12 @@ class Erf(SpecificFunction): def __init__(self, child): super().__init__(special.erf, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(special.erf, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return 2 / np.sqrt(np.pi) * exp(-children[0] ** 2) @@ -389,6 +420,12 @@ class Exp(SpecificFunction): def __init__(self, child): super().__init__(np.exp, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.exp, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return exp(children[0]) @@ -405,6 +442,12 @@ class Log(SpecificFunction): def __init__(self, child): super().__init__(np.log, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.log, snippet) + return instance + def _function_evaluate(self, evaluated_children): # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): @@ -435,6 +478,12 @@ class Max(SpecificFunction): def __init__(self, child): super().__init__(np.max, child) + @classmethod + def _from_json(cls, snippet: dict): + """See 
:meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.max, snippet) + return instance + def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" # Max will always return a scalar @@ -455,6 +504,12 @@ class Min(SpecificFunction): def __init__(self, child): super().__init__(np.min, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.min, snippet) + return instance + def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" # Min will always return a scalar @@ -480,6 +535,12 @@ class Sin(SpecificFunction): def __init__(self, child): super().__init__(np.sin, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.sin, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return cos(children[0]) @@ -496,6 +557,12 @@ class Sinh(SpecificFunction): def __init__(self, child): super().__init__(np.sinh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.sinh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return cosh(children[0]) @@ -512,6 +579,12 @@ class Sqrt(SpecificFunction): def __init__(self, child): super().__init__(np.sqrt, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.sqrt, snippet) + return instance + def _function_evaluate(self, evaluated_children): # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): @@ -533,6 +606,12 @@ class Tanh(SpecificFunction): def __init__(self, child): super().__init__(np.tanh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.tanh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return sech(children[0]) ** 2 diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index efeb73f8bc..665bfdb344 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -34,6 +34,14 @@ def __init__(self, name, domain=None, auxiliary_domains=None, domains=None): name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__(snippet["name"], domains=snippet["domains"]) + + return instance + def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) @@ -58,6 +66,14 @@ class Time(IndependentVariable): def __init__(self): super().__init__("time") + @classmethod + def _to_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__("time") + + return instance + def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Time() diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 1f772bc325..e66a4c8cdc 100644 --- 
a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -35,6 +35,18 @@ def __init__(self, name, domain=None, expected_size=None): self._expected_size = expected_size super().__init__(name, domain=domain) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__( + snippet["name"], + domain=snippet["domain"], + expected_size=snippet["expected_size"], + ) + + return instance + def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" new_input_parameter = InputParameter( diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 9555dcaa34..16bbe88d7e 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -291,21 +291,5 @@ def _function_evaluate(self, evaluated_children): else: # pragma: no cover raise ValueError("Invalid dimension: {0}".format(self.dimension)) - # PL: think I need something here. presumably I can serialise function methods using just their names, then rehydrate them at the point they're read back in? def to_json(self): - """ - Method to serialise an Interpolant object into JSON. - """ - - json_dict = { - "name": self.name, - "id": self.id, - # "domains": self.domains, - "x": self.x.tolist(), - "y": self.y.tolist(), - "interpolator": self.interpolator, - "extrapolate": self.extrapolate, - # "entries_string": self.entries_string, - } - - return json_dict + raise NotImplementedError diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index ae2b63560d..9f7d1aa368 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -29,6 +29,14 @@ def __init__(self, value, name=None): super().__init__(name) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__(snippet["value"], name=snippet["name"]) + + return instance + def __str__(self): return str(self.value) diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 2c101e0a24..9a414dc049 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -73,6 +73,21 @@ def __init__( domains=domains, ) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + y_slices = [slice(s["start"], s["stop"], s["step"]) for s in snippet["y_slice"]] + + instance.__init__( + *y_slices, + name=snippet["name"], + domains=snippet["domains"], + evaluation_array=snippet["evaluation_array"], + ) + + return instance + @property def y_slices(self): return self._y_slices diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 037205fda0..b0747090cd 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -234,6 +234,26 @@ def __init__( ): self.test_shape() + @classmethod + def _from_json(cls, snippet: dict): + """ + Reconstructs a Symbol instance during deserialisation of a JSON file. + + Parameters + ---------- + snippet: dict + Contains the information needed to reconstruct a specific instance. + At minimum, should contain "name", "children" and "domains". 
+ """ + + instance = cls.__new__(cls) + + instance.__init__( + snippet["name"], children=snippet["children"], domains=snippet["domains"] + ) + + return instance + @property def children(self): """ diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index efd0914464..c2fd6c2232 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -353,6 +353,15 @@ class with a :class:`Matrix` def __init__(self, name, child, domains=None): super().__init__(name, child, domains) + def diff(self, variable): + """See :meth:`pybamm.Symbol.diff()`.""" + # We shouldn't need this + raise NotImplementedError + + def to_json(self): + # Will not be present in a discretised model + raise NotImplementedError + class Gradient(SpatialOperator): """ @@ -814,20 +823,6 @@ def _evaluates_on_edges(self, dimension): """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False - def to_json(self): - """ - Method to serialise a BoundaryIntegral object into JSON. - """ - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, # PL: Not sure if this exists, but might inherit from symbol - "region": self.region, - } - - return json_dict - class DeltaFunction(SpatialOperator): """ @@ -872,20 +867,6 @@ def evaluate_for_shape(self): return np.outer(child_eval, vec).reshape(-1, 1) - def to_json(self): - """ - Method to serialise a DeltaFunction object into JSON. - """ - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, - "side": self.side, - } - - return json_dict - class BoundaryOperator(SpatialOperator): """ @@ -938,20 +919,6 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def to_json(self): - """ - Method to serialise a BoundaryOperator object into JSON. - """ - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, - "side": self.side, - } - - return json_dict - class BoundaryValue(BoundaryOperator): """ diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 1349901a9a..8aa2b1d707 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -131,22 +131,8 @@ def to_equation(self): def to_json( self, - ): # PL: This may never be touched if once discretised, it's turned into a statevector/statevectordot type. - """ - Method to serialise a BoundaryOperator object into JSON. - """ - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, - "bounds": self.bounds, # tuple - "print_name": self.print_name, # string - "scale": self.scale, # float/symbol - "reference": self.reference, # float/symbol - } - - return json_dict + ): + raise NotImplementedError class Variable(VariableBase): diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 7e1f9b060f..4def187f41 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -123,6 +123,7 @@ def __init__(self, name="Unnamed model"): self.is_discretised = False self.y_slices = None + # PL: Next up, how to pass in the non-standard variables, if necessary. 
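+    # Intended usage (sketch): the property dict rebuilt from JSON is passed
+    # straight to this classmethod, e.g.
+    #     new_model = pybamm.lithium_ion.DFN.deserialise(recon_model_dict)
+    # where recon_model_dict contains the reconstructed concatenated rhs,
+    # algebraic and initial-condition expression trees.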
@classmethod def deserialise(cls, properties: dict): """ @@ -130,40 +131,17 @@ def deserialise(cls, properties: dict): """ instance = cls.__new__(cls) - instance.name = properties["name"] - instance._options = {} - instance._built = False - instance._built_fundamental = False + instance.__init__(name=properties["name"]) # Initialise model with stored variables - instance.submodels = {} - instance._rhs = {} - instance._algebraic = {} - instance._initial_conditions = {} - instance._boundary_conditions = {} - instance._variables = pybamm.FuzzyDict({}) - instance._events = [] instance._concatenated_rhs = properties["concatenated_rhs"] instance._concatenated_algebraic = properties["concatenated_algebraic"] instance._concatenated_initial_conditions = properties[ "concatenated_initial_conditions" ] - instance._mass_matrix = None - instance._mass_matrix_inv = None - instance._jacobian = None - instance._jacobian_algebraic = None - instance._parameters = None - instance._input_parameters = None - instance._parameter_info = None - instance._variables_casadi = {} - - # Default behaviour is to use the jacobian - instance.use_jacobian = True - instance.convert_to_format = "casadi" # Model has already been discretised instance.is_discretised = True - instance.y_slices = None return instance diff --git a/pybamm/serialisation/serialisation.py b/pybamm/serialisation/serialisation.py index 7075c7839d..606fd82688 100644 --- a/pybamm/serialisation/serialisation.py +++ b/pybamm/serialisation/serialisation.py @@ -47,68 +47,9 @@ def recreate_slice(d): class_ = getattr(module, parts[-1]) foo.__class__ = class_ + # foo = foo._from_json(dct) -> PL: This is what we want eventually - # PL: Need to finish off the various options here. - if isinstance(foo, pybamm.Scalar): - foo.__init__(dct["value"], name=dct["name"]) - - elif isinstance(foo, pybamm.BinaryOperator): - foo = foo._from_json(dct["children"][0], dct["children"][1], dct["domains"]) - - elif isinstance(foo, pybamm.Array): - if isinstance(dct["entries"], dict): - matrix = csr_array( - ( - dct["entries"]["data"], - dct["entries"]["row_indices"], - dct["entries"]["column_pointers"], - ), - shape=dct["entries"]["shape"], - ) - else: - matrix = dct["entries"] - foo.__init__( - matrix, - name=dct["name"], - domains=dct["domains"], - # entries_string=dct["entries_string"], - ) - - elif isinstance(foo, pybamm.StateVectorBase): - y_slices = [recreate_slice(d) for d in dct["y_slice"]] - foo.__init__( - *y_slices, - name=dct["name"], - domains=dct["domains"], - evaluation_array=dct["evaluation_array"], - ) - - elif isinstance(foo, pybamm.IndependentVariable): - if isinstance(foo, pybamm.Time): - foo.__init__() - else: - foo.__init__(dct["name"], domains=dct["domains"]) - - elif isinstance(foo, pybamm.InputParameter): - foo.__init__( - dct["name"], domain=dct["domain"], expected_size=dct["expected_size"] - ) - - elif isinstance(foo, pybamm.SpecificFunction): - foo.__init__(dct["children"][0]) - - elif isinstance(foo, pybamm.Function): - func = getattr( - np, dct["function"] - ) # don't think this will work for self-defined functions - foo.__init__( - func, - name=dct["name"], - derivative=dct["derivative"], - differentiated_function=dct["differentiated_function"], - ) - - elif isinstance(foo, pybamm.DomainConcatenation): + if isinstance(foo, pybamm.DomainConcatenation): def repack_defaultDict(slices): slices = defaultdict(list, slices) @@ -134,12 +75,9 @@ def repack_defaultDict(slices): dct["children"], dct["domains"], ) - # interpolant - # check various Unary 
operators, they differ - # VariableBase - # ... - elif isinstance(foo, pybamm.Symbol): - foo.__init__(dct["name"], children=dct["children"], domains=dct["domains"]) + + else: + foo = foo._from_json(dct) return foo From 4ea81086592be4282df401bcb7f2eb6f6c1b953d Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 24 Aug 2023 16:37:39 +0000 Subject: [PATCH 03/29] Create Serialise class Stores save_/load_model functions. Currently working for default models. Add errors, make accessible from Simulation --- pybamm/__init__.py | 1 + pybamm/expression_tree/broadcasts.py | 8 +- .../expression_tree/operations/serialise.py | 182 ++++++++++++++++++ pybamm/expression_tree/parameter.py | 16 +- pybamm/expression_tree/unary_operators.py | 39 +--- pybamm/models/base_model.py | 36 ++++ pybamm/serialisation/serialisation.py | 170 ---------------- pybamm/simulation.py | 22 +++ 8 files changed, 271 insertions(+), 203 deletions(-) create mode 100644 pybamm/expression_tree/operations/serialise.py delete mode 100644 pybamm/serialisation/serialisation.py diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 6c2636ba51..d7b957e1c9 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -93,6 +93,7 @@ from .expression_tree.operations.jacobian import Jacobian from .expression_tree.operations.convert_to_casadi import CasadiConverter from .expression_tree.operations.unpack_symbols import SymbolUnpacker +from .models.base_model import load_model # # Model classes diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 45c37a55f0..a9bd5c2ee2 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -52,7 +52,13 @@ def _diff(self, variable): def to_json(self): raise NotImplementedError( - "pybamm.Broadcast: Serialisation is only implemented for post-discretisation." # PL: Come up with a better message! + "pybamm.Broadcast: Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.Broadcast: Please use a discretised model when reading in from JSON." ) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py new file mode 100644 index 0000000000..c88f32e602 --- /dev/null +++ b/pybamm/expression_tree/operations/serialise.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import pybamm +from datetime import datetime +import json +import importlib +import numpy as np + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pybamm import BaseBatteryModel + + +class Serialise: + """ + Converts a discretised model to and from a JSON file. + + """ + + def __init__(self): + pass + + class _SymbolEncoder(json.JSONEncoder): + """Converts PyBaMM symbols into a JSON-serialisable format""" + + def default(self, node: dict): + node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + if isinstance(node, pybamm.Symbol): + node_dict.update(node.to_json()) # this doesn't include children + node_dict["children"] = [] + for c in node.children: + node_dict["children"].append(self.default(c)) + + return node_dict + + json_obj = json.JSONEncoder.default(self, node) + node_dict["json"] = json_obj + return node_dict + + class _Empty: + """A dummy class to aid deserialisation""" + + pass + + def save_model(self, model, filename=None): + """ + Saves a discretised model to a JSON file. 
+ + As the model is discretised and ready to solve, only the right hand side, + algebraic and initial condition variables are saved. + + Parameters + ---------- + model: : :class:`pybamm.BaseModel` + The discretised model to be saved + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be + created based on the model name, and the current datetime. + """ + if model.is_discretised == False: + raise NotImplementedError( + "PyBaMM can only serialise a discretised, ready-to-solve model." + ) + + model_json = { + "py/object": str(type(model))[8:-2], + "py/id": id(model), + "name": model.name, + "options": model.options, + "bounds": [bound.tolist() for bound in model.bounds], + "concatenated_rhs": self._SymbolEncoder().default(model._concatenated_rhs), + "concatenated_algebraic": self._SymbolEncoder().default( + model._concatenated_algebraic + ), + "concatenated_initial_conditions": self._SymbolEncoder().default( + model._concatenated_initial_conditions + ), + } + + if filename is None: + filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M_%S") + + with open(filename + ".json", "w") as f: + json.dump(model_json, f) + + def load_model(self, filename: str, battery_model: BaseBatteryModel = None): + """ + Loads a discretised, ready to solve model into PyBaMM. + + A new pybamm battery model instance will be created, which can be solved + and the results plotted as usual. + + Currently only available for pybamm models which have previously been written + out using the `save_model()` option. + + Warning: This only loads in discretised models. If you wish to make edits to the + model or initial conditions, a new model will need to be constructed seperately. + + Parameters + ---------- + + filename: str + Path to the JSON file containing the serialised model file + battery_model: :class: pybamm.BaseBatteryModel, optional + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override + any model names within the file. If None, the function will look for the saved object + path, present if the original model came from PyBaMM. + """ + + with open(filename, "r") as f: + model_data = json.load(f) + + recon_model_dict = { + "name": model_data["name"], + "options": model_data["options"], + "bounds": tuple(np.array(bound) for bound in model_data["bounds"]), + "concatenated_rhs": self._reconstruct_epression_tree( + model_data["concatenated_rhs"] + ), + "concatenated_algebraic": self._reconstruct_epression_tree( + model_data["concatenated_algebraic"] + ), + "concatenated_initial_conditions": self._reconstruct_epression_tree( + model_data["concatenated_initial_conditions"] + ), + } + + if battery_model: + return battery_model.deserialise(recon_model_dict) + + if "py/object" in model_data.keys(): + model_framework = self._get_pybamm_class(model_data) + return model_framework.deserialise(recon_model_dict) + + raise TypeError( + """ + The PyBaMM battery model to use has not been provided. 
+ """ + ) + + def _get_pybamm_class(self, snippet: dict): + """Find a pybamm class to initialise from object path""" + empty_class = self._Empty() + parts = snippet["py/object"].split(".") + try: + module = importlib.import_module(".".join(parts[:-1])) + except Exception as ex: + print(ex) + + class_ = getattr(module, parts[-1]) + empty_class.__class__ = class_ + + return empty_class + + def _reconstruct_symbol(self, dct: dict): + """Reconstruct an individual pybamm Symbol""" + symbol_class = self._get_pybamm_class(dct) + symbol = symbol_class._from_json(dct) + return symbol + + def _reconstruct_epression_tree(self, node: dict): + """ + Loop through an expression tree creating pybamm Symbol classes + + Conducts post-order tree traversal to turn each tree node into a + `pybamm.Symbol` class, starting from leaf nodes without children and + working upwards. + + Parameters + ---------- + node: dict + A node in an expression tree. + """ + if "children" in node: + for i, c in enumerate(node["children"]): + child_obj = self._reconstruct_epression_tree(c) + node["children"][i] = child_obj + + obj = self._reconstruct_symbol(node) + + return obj diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index d8aa146fd9..abf50faa75 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -51,7 +51,13 @@ def to_equation(self): def to_json(self): raise NotImplementedError( - "pybamm.Parameter: Serialisation is only implemented for post-discretisation." # PL: Come up with a better message! + "pybamm.Parameter: Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.Parameter: Please use a discretised model when reading in from JSON." ) @@ -229,5 +235,11 @@ def to_equation(self): def to_json(self): raise NotImplementedError( - "pybamm.FunctionParameter: Serialisation is only implemented for post-discretisation." # PL: Come up with a better message! + "pybamm.FunctionParameter: Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.FunctionParameter: Please use a discretised model when reading in from JSON." ) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index c2fd6c2232..a828b8442b 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -359,8 +359,15 @@ def diff(self, variable): raise NotImplementedError def to_json(self): - # Will not be present in a discretised model - raise NotImplementedError + raise NotImplementedError( + "pybamm.SpatialOperator: Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.SpatialOperator: Please use a discretised model when reading in from JSON." + ) class Gradient(SpatialOperator): @@ -604,20 +611,6 @@ def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" return sympy.Integral(child, sympy.Symbol("xn")) - def to_json(self): - """ - Method to serialise an Integral object into JSON. 
- """ - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, - "integration_variable": self.integration_variable, # PL: This may be a (list of) symbols that need cycling through in a similar mannar to children - } - - return json_dict - class BaseIndefiniteIntegral(Integral): """ @@ -752,20 +745,6 @@ def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) - def to_json(self): - """ - Method to serialise a DefiniteIntegralVector object into JSON. - """ - - json_dict = { - "name": self.name, - "id": self.id, - "domains": self.domains, - "vector_type": self.vector_type, - } - - return json_dict - class BoundaryIntegral(SpatialOperator): """ diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 4def187f41..f407a15c08 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -10,6 +10,7 @@ import pybamm from pybamm.expression_tree.operations.latexify import Latexify +from pybamm.expression_tree.operations.serialise import Serialise class BaseModel: @@ -134,6 +135,7 @@ def deserialise(cls, properties: dict): instance.__init__(name=properties["name"]) # Initialise model with stored variables + instance._options = properties["options"] # For information only instance._concatenated_rhs = properties["concatenated_rhs"] instance._concatenated_algebraic = properties["concatenated_algebraic"] instance._concatenated_initial_conditions = properties[ @@ -143,6 +145,12 @@ def deserialise(cls, properties: dict): # Model has already been discretised instance.is_discretised = True + instance.len_rhs = instance.concatenated_rhs.size + instance.len_alg = instance.concatenated_algebraic.size + instance.len_rhs_and_alg = instance.len_rhs + instance.len_alg + + instance.bounds = properties["bounds"] + return instance @property @@ -1132,6 +1140,34 @@ def process_parameters_and_discretise(self, symbol, parameter_values, disc): return disc_symbol + def save_model(self, filename=None): + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. + """ + Serialise().save_model(self, filename=filename) + + +def load_model(filename, battery_model: BaseModel = None): + """ + Load in a saved model from a JSON file + + Parameters + ---------- + filename: str + Path to the JSON file containing the serialised model file + battery_model: :class: pybamm.BaseBatteryModel, optional + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override + any model names within the file. If None, the function will look for the saved object + path, present if the original model came from PyBaMM. 
+ """ + return Serialise().load_model(filename, battery_model) + # helper functions for finding symbols def find_symbol_in_tree(tree, name): diff --git a/pybamm/serialisation/serialisation.py b/pybamm/serialisation/serialisation.py deleted file mode 100644 index 606fd82688..0000000000 --- a/pybamm/serialisation/serialisation.py +++ /dev/null @@ -1,170 +0,0 @@ -import pybamm -from anytree.exporter import JsonExporter -from anytree.importer import JsonImporter -import json -import numpy as np -import pprint -import importlib -from scipy.sparse import csr_matrix, csr_array -from collections import defaultdict - - -class SymbolEncoder(json.JSONEncoder): - def default(self, node): - node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} - if isinstance(node, pybamm.Symbol): - node_dict.update(node.to_json()) # this doesn't include children - node_dict["children"] = [] - for c in node.children: - node_dict["children"].append(self.default(c)) - - return node_dict - - json_obj = json.JSONEncoder.default(self, node) - node_dict["json"] = json_obj - return node_dict - - -## DECODE - - -class _Empty: - pass - - -def reconstruct_symbol(dct): - def recreate_slice(d): - return slice(d["start"], d["stop"], d["step"]) - - # decode non-symbol objects here - # now for pybamm - foo = _Empty() - parts = dct["py/object"].split(".") - try: - module = importlib.import_module(".".join(parts[:-1])) - except Exception as ex: - print(ex) - - class_ = getattr(module, parts[-1]) - foo.__class__ = class_ - # foo = foo._from_json(dct) -> PL: This is what we want eventually - - if isinstance(foo, pybamm.DomainConcatenation): - - def repack_defaultDict(slices): - slices = defaultdict(list, slices) - for domain, sls in slices.items(): - sls = [recreate_slice(s) for s in sls] - slices[domain] = sls - return slices - - main_slice = repack_defaultDict(dct["slices"]) - child_slice = [repack_defaultDict(s) for s in dct["children_slices"]] - - foo = foo._from_json( - dct["children"], - dct["size"], - main_slice, - child_slice, - dct["secondary_dimensions_npts"], - dct["domains"], - ) - - elif isinstance(foo, pybamm.NumpyConcatenation): - foo = foo._from_json( - dct["children"], - dct["domains"], - ) - - else: - foo = foo._from_json(dct) - - return foo - - -def reconstruct_epression_tree(node): - if "children" in node: - for i, c in enumerate(node["children"]): - child_obj = reconstruct_epression_tree(c) - node["children"][i] = child_obj - - obj = reconstruct_symbol(node) - - return obj - - -## Run tests -model = pybamm.lithium_ion.DFN() -geometry = model.default_geometry -param = model.default_parameter_values -param.process_model(model) -param.process_geometry(geometry) -mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) -disc = pybamm.Discretisation(mesh, model.default_spatial_methods) -disc.process_model(model) - -# # tested all individual trees in rhs -# # tree1 = list(model.rhs.items())[2][1] -# tree1 = ( -# model.y_slices -# ) # Worked: concatenated_rhs, concat_initial_conditions, concatenated_algebraic. -# # Do we need the 'unconcatenated' rhs etc? if not, this gets much easier. 
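For reference, the round trip that this commented-out block exercised looks roughly like the sketch below. It is only a sketch: it assumes the module-level `SymbolEncoder` and `reconstruct_epression_tree` defined in this (now deleted) script are in scope, and it builds an SPM rather than a DFN purely to keep the example quick.

import json

import pybamm

# Build and discretise a model so the concatenated expression trees exist.
model = pybamm.lithium_ion.SPM()
sim = pybamm.Simulation(model)
sim.build()
tree = sim.built_model.concatenated_rhs

# Symbol -> nested dict -> JSON on disk.
encoded = SymbolEncoder().default(tree)
with open("tree.json", "w") as f:
    json.dump(encoded, f)

# JSON -> nested dict -> pybamm symbols, rebuilt leaf-first.
with open("tree.json") as f:
    data = json.load(f)
rebuilt = reconstruct_epression_tree(data)

# The script above compares the original and reconstructed trees the same way.
print(tree == rebuilt)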
-# tree1.visualise("tree1.png") - -# json_tree1 = SymbolEncoder().default(tree1) -# with open("test_tree1.json", "w") as f: -# json.dump(json_tree1, f) - -# # pprint.pprint(json_tree1, sort_dicts=False) - -# with open("test_tree1.json", "r") as f: -# data = json.load(f) - -# tree1_recon = reconstruct_epression_tree(data) - -# print(tree1 == tree1_recon) - - -# tree1_recon.visualise("recon1.png") - -solver_initial = model.default_solver -solution_initial = solver_initial.solve(model, [0, 3600]) - -# pybamm.plot(solution_initial) -# solution_initial.plot() - -model_json = { - "py/object": str(type(model))[8:-2], - "py/id": id(model), - "name": model.name, - "concatenated_rhs": SymbolEncoder().default(model._concatenated_rhs), - "concatenated_algebraic": SymbolEncoder().default(model._concatenated_algebraic), - "concatenated_initial_conditions": SymbolEncoder().default( - model._concatenated_initial_conditions - ), -} - -# file_name = f"test_{model.name}_stored" -with open("test_full_model.json", "w") as f: - json.dump(model_json, f) - -with open("test_full_model.json", "r") as f: - model_data = json.load(f) - -recon_model_dict = { - "name": model_data["name"], - "concatenated_rhs": reconstruct_epression_tree(model_data["concatenated_rhs"]), - "concatenated_algebraic": reconstruct_epression_tree( - model_data["concatenated_algebraic"] - ), - "concatenated_initial_conditions": reconstruct_epression_tree( - model_data["concatenated_initial_conditions"] - ), -} - -new_model = pybamm.lithium_ion.DFN.deserialise(recon_model_dict) - -new_solver = new_model.default_solver -new_solution = new_solver.solve(model, [0, 3600]) - -# THIS WORKS!!! diff --git a/pybamm/simulation.py b/pybamm/simulation.py index b25b76f859..da2bac841b 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -11,6 +11,8 @@ from datetime import timedelta import tqdm +from pybamm.expression_tree.operations.serialise import Serialise + def is_notebook(): try: @@ -1186,6 +1188,26 @@ def save(self, filename): with open(filename, "wb") as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) + def save_model(self, filename: str = None): + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. + """ + if self.built_model: + Serialise().save_model(self.built_model, filename=filename) + else: + raise NotImplementedError( + """ + PyBaMM can only serialise a discretised model. + Ensure the model has been built (e.g. run `solve()`) before saving. + """ + ) + def load_sim(filename): """Load a saved simulation""" From 6694cb1025156e13a26fc87af91f81835623dde5 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 1 Sep 2023 16:21:28 +0000 Subject: [PATCH 04/29] Serialised models can be plotted. 
Option to save mesh, variables and geometry Draft notebook written Added warning if variables are not provided and try to plot --- .../notebooks/models/saving_models.ipynb | 342 ++++++++++++++++++ pybamm/__init__.py | 6 +- pybamm/expression_tree/array.py | 2 +- pybamm/expression_tree/functions.py | 14 +- .../expression_tree/independent_variable.py | 4 +- .../expression_tree/operations/serialise.py | 127 ++++++- pybamm/expression_tree/unary_operators.py | 48 ++- pybamm/meshes/meshes.py | 23 ++ pybamm/meshes/one_dimensional_submeshes.py | 23 ++ pybamm/meshes/zero_dimensional_submesh.py | 17 + pybamm/models/base_model.py | 58 --- pybamm/models/event.py | 38 ++ .../full_battery_models/base_battery_model.py | 96 +++++ pybamm/plotting/quick_plot.py | 4 + pybamm/simulation.py | 25 +- 15 files changed, 748 insertions(+), 79 deletions(-) create mode 100644 docs/source/examples/notebooks/models/saving_models.ipynb diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb new file mode 100644 index 0000000000..94799bcc48 --- /dev/null +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -0,0 +1,342 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Saving PyBaMM models to file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Models which are discretised (i.e. ready to solve/ previously solved, see A DIFFERENT NOTEBOOK) can be serialised and saved to a JSON file, ready to be read in again either in PyBaMM, or a different modelling library. \n", + "\n", + "In the example below, we build and solve a basic DFN model, and then save the model out to `sim_model_example.json`, which should have appear in the 'models' directory." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install pybamm -q # install PyBaMM if it is not installed\n", + "import pybamm\n", + "\n", + "# do the example\n", + "dfn_model = pybamm.lithium_ion.DFN()\n", + "dfn_sim = pybamm.Simulation(dfn_model)\n", + "dfn_sim.solve([0, 3600])\n", + "\n", + "dfn_sim.save_model(\"sim_model_example\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This model file can then be read in and solved by choosing a solver, and running as below." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Recreate the pybamm model from the JSON file\n", + "new_dfn_model = pybamm.load_model(\"sim_model_example.json\")\n", + "\n", + "sim_reloaded = pybamm.Simulation(new_dfn_model) # PL: will this work if anything other than the default options are used? 
I guess not...\n", + "sim_reloaded.solve([0, 3600])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It would be nice to plot the results of the two models, to confirm that they are producing the same result.\n", + "\n", + "However, notice that running the code below generates an error stating that the model variables were not provided during the reading in of the model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "Variables not provided by the serialised model", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/pliggins/PyBaMM/docs/source/examples/notebooks/models/saving_models.ipynb Cell 7\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m plot_sim\u001b[39m.\u001b[39msolve([\u001b[39m0\u001b[39m, \u001b[39m3600\u001b[39m])\n\u001b[1;32m 6\u001b[0m sims\u001b[39m.\u001b[39mappend(plot_sim)\n\u001b[0;32m----> 8\u001b[0m pybamm\u001b[39m.\u001b[39;49mdynamic_plot(sims, time_unit\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mseconds\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", + "File \u001b[0;32m~/PyBaMM/pybamm/plotting/dynamic_plot.py:20\u001b[0m, in \u001b[0;36mdynamic_plot\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[39mCreates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[39marguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39m The 'QuickPlot' object that was created\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 19\u001b[0m kwargs_for_class \u001b[39m=\u001b[39m {k: v \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m kwargs\u001b[39m.\u001b[39mitems() \u001b[39mif\u001b[39;00m k \u001b[39m!=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[0;32m---> 20\u001b[0m plot \u001b[39m=\u001b[39m pybamm\u001b[39m.\u001b[39;49mQuickPlot(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs_for_class)\n\u001b[1;32m 21\u001b[0m plot\u001b[39m.\u001b[39mdynamic_plot(kwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m))\n\u001b[1;32m 22\u001b[0m \u001b[39mreturn\u001b[39;00m plot\n", + "File \u001b[0;32m~/PyBaMM/pybamm/plotting/quick_plot.py:163\u001b[0m, in \u001b[0;36mQuickPlot.__init__\u001b[0;34m(self, solutions, output_variables, labels, colors, linestyles, shading, figsize, n_rows, time_unit, spatial_unit, variable_limits)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[39m# check variables have been provided after any serialisation\u001b[39;00m\n\u001b[1;32m 162\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(\u001b[39mlen\u001b[39m(m\u001b[39m.\u001b[39mvariables) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m \u001b[39mfor\u001b[39;00m m \u001b[39min\u001b[39;00m models):\n\u001b[0;32m--> 163\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mVariables not provided by the serialised 
model\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 165\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows \u001b[39m=\u001b[39m n_rows \u001b[39mor\u001b[39;00m \u001b[39mint\u001b[39m(\n\u001b[1;32m 166\u001b[0m \u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m np\u001b[39m.\u001b[39msqrt(\u001b[39mlen\u001b[39m(output_variables))\n\u001b[1;32m 167\u001b[0m )\n\u001b[1;32m 168\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_cols \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(np\u001b[39m.\u001b[39mceil(\u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows))\n", + "\u001b[0;31mAttributeError\u001b[0m: Variables not provided by the serialised model" + ] + } + ], + "source": [ + "dfn_models = [dfn_model, new_dfn_model]\n", + "sims = []\n", + "for model in dfn_models:\n", + " plot_sim = pybamm.Simulation(model)\n", + " plot_sim.solve([0, 3600])\n", + " sims.append(plot_sim)\n", + "\n", + "pybamm.dynamic_plot(sims, time_unit=\"seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To be able to plot the results from a serialised model, the mesh and model variables need to be saved alongside the model itself.\n", + "\n", + "To do this, set the `variables` option to `True` when saving the model as in the example below; notice how the models will now plot nicely." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b6b4db83fd054ba4be3ee279f7024c6a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='t', max=3600.0, step=36.0), Output()), _dom_classes=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# using the first simulation, save a new file which includes a list of the model variables\n", + "dfn_sim.save_model(\"sim_model_variables\", variables=True)\n", + "\n", + "# read the model back in\n", + "model_with_vars = pybamm.load_model(\"sim_model_variables.json\")\n", + "\n", + "# plot the pre- and post-serialisation models together to prove they behave the same\n", + "models = [dfn_model, model_with_vars]\n", + "sims = []\n", + "for model in models:\n", + " sim = pybamm.Simulation(model)\n", + " sim.solve([0, 3600])\n", + " sims.append(sim)\n", + "\n", + "pybamm.dynamic_plot(sims, time_unit=\"seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving from Model\n", + "\n", + "Alternatively, the model can be saved directly from the Model class.\n", + "\n", + "First set up the model, as explained in detail in the SPM NOTEBOOK" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create the model\n", + "spm_model = pybamm.lithium_ion.SPM()\n", + "\n", + "# set up and discretise ready to solve\n", + "geometry = spm_model.default_geometry\n", + "param = spm_model.default_parameter_values\n", + "param.process_model(spm_model)\n", + "param.process_geometry(geometry)\n", + "mesh = pybamm.Mesh(geometry, spm_model.default_submesh_types, spm_model.default_var_pts)\n", + "disc = 
pybamm.Discretisation(mesh, spm_model.default_spatial_methods)\n", + "disc.process_model(spm_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then save the model. Note that in this case the model variables and the mesh must be provided directly." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Serialise the spm_model, providing the varaibles and the mesh\n", + "spm_model.save_model(\"example_model\", variables=spm_model.variables, mesh=mesh)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now you can read the model back in, solve and plot." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b6df594b3af646599430ff322349b44f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read back in\n", + "new_spm_model = pybamm.load_model(\"example_model.json\")\n", + "\n", + "# select a solver and solve\n", + "new_spm_solver = new_spm_model.default_solver\n", + "new_spm_solution = new_spm_solver.solve(new_spm_model, [0, 3600])\n", + "\n", + "# plot the solution\n", + "new_spm_solution.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "The relevant papers for this notebook are:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1] Joel A. E. Andersson, Joris Gillis, Greg Horn, James B. Rawlings, and Moritz Diehl. CasADi – A software framework for nonlinear optimization and optimal control. Mathematical Programming Computation, 11(1):1–36, 2019. doi:10.1007/s12532-018-0139-4.\n", + "[2] Marc Doyle, Thomas F. Fuller, and John Newman. Modeling of galvanostatic charge and discharge of the lithium/polymer/insertion cell. Journal of the Electrochemical society, 140(6):1526–1533, 1993. doi:10.1149/1.2221597.\n", + "[3] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n", + "[4] Scott G. Marquis, Valentin Sulzer, Robert Timms, Colin P. Please, and S. Jon Chapman. An asymptotic derivation of a single particle model with electrolyte. Journal of The Electrochemical Society, 166(15):A3693–A3706, 2019. doi:10.1149/2.0341915jes.\n", + "[5] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. 
doi:10.5334/jors.309.\n", + "\n" + ] + } + ], + "source": [ + "pybamm.print_citations()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pybamm/__init__.py b/pybamm/__init__.py index d7b957e1c9..c6376ec0b3 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -93,7 +93,6 @@ from .expression_tree.operations.jacobian import Jacobian from .expression_tree.operations.convert_to_casadi import CasadiConverter from .expression_tree.operations.unpack_symbols import SymbolUnpacker -from .models.base_model import load_model # # Model classes @@ -189,6 +188,11 @@ UserSupplied2DSubMesh, ) +# +# Serialisation +# +from .models.full_battery_models.base_battery_model import load_model + # # Spatial Methods # diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index c2fcbc4a11..270c546dbe 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -62,7 +62,7 @@ def _from_json(cls, snippet: dict): instance = cls.__new__(cls) if isinstance(snippet["entries"], dict): - matrix = csr_array( + matrix = csr_matrix( ( snippet["entries"]["data"], snippet["entries"]["row_indices"], diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 17732c7ba4..9743e7c754 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -212,8 +212,18 @@ def to_equation(self): eq_list.append(eq) return self._sympy_operator(*eq_list) - def to_json(self): - raise NotImplementedError() + def to_json( + self, + ): # PL: I think these ones might actually be present when you build your own function. + raise NotImplementedError( + "pybamm.Function: Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.Function: Please use a discretised model when reading in from JSON." 
+ ) def simplified_function(func_class, child): diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 665bfdb344..fbf745dfca 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -67,10 +67,10 @@ def __init__(self): super().__init__("time") @classmethod - def _to_json(cls, snippet: dict): + def _from_json(cls, snippet: dict): instance = cls.__new__(cls) - instance.__init__("time") + instance.__init__() return instance diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index c88f32e602..e11049a35a 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -32,9 +32,42 @@ def default(self, node: dict): for c in node.children: node_dict["children"].append(self.default(c)) + if hasattr(node, "initial_condition"): # for ExplicitTimeIntegral + node_dict["initial_condition"] = self.default( + node.initial_condition + ) + + return node_dict + + if isinstance(node, pybamm.Event): + node_dict.update(node.to_json()) + node_dict["expression"] = self.default(node._expression) return node_dict - json_obj = json.JSONEncoder.default(self, node) + json_obj = json.JSONEncoder.default(self, node) # pragma: no cover + node_dict["json"] = json_obj + return node_dict + + class _MeshEncoder(json.JSONEncoder): + """Converts PyBaMM meshes into a JSON-serialisable format""" + + def default(self, node: dict): + node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + if isinstance(node, pybamm.Mesh): + node_dict.update(node.to_json()) + + node_dict["sub_meshes"] = {} + for k, v in node.items(): + if len(k) == 1 and "ghost cell" not in k[0]: + node_dict["sub_meshes"][k[0]] = self.default(v) + + return node_dict + + if isinstance(node, pybamm.SubMesh): + node_dict.update(node.to_json()) + return node_dict + + json_obj = json.JSONEncoder.default(self, node) # pragma: no cover node_dict["json"] = json_obj return node_dict @@ -43,7 +76,18 @@ class _Empty: pass - def save_model(self, model, filename=None): + class _EmptyDict(dict): + """A dummy dictionary class to aid deserialisation""" + + pass + + def save_model( + self, + model: pybamm.BaseBatteryModel, + mesh: pybamm.Mesh = None, + variables: pybamm.FuzzyDict = None, + filename: str = None, + ): """ Saves a discretised model to a JSON file. @@ -52,8 +96,14 @@ def save_model(self, model, filename=None): Parameters ---------- - model: : :class:`pybamm.BaseModel` + model: : :class:`pybamm.BaseBatteryModel` The discretised model to be saved + mesh: :class: `pybamm.Mesh`, optional + The mesh the model has been discretised over. Not neccesary to solve + the model when read in, but required to use pybamm's plotting tools. + variables: :class: `pybamm.FuzzyDict`, optional + The discretised model varaibles. Not necessary to solve a model, but + required to use pybamm's plotting tools. filename: str, optional The desired name of the JSON file. If no name is provided, one will be created based on the model name, and the current datetime. 
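Taken together with the `pybamm.load_model` entry point and the `Simulation.save_model` changes later in this patch, the user-facing flow these hunks enable looks roughly like the following sketch, which mirrors the notebook added above (an SPM is used here simply because it is quick to solve):

import pybamm

model = pybamm.lithium_ion.SPM()
sim = pybamm.Simulation(model)
sim.solve([0, 3600])

# variables=True also stores the mesh, so the plotting tools work on the
# reloaded model.
sim.save_model("spm_saved", variables=True)

# Reload, solve again and plot both solutions side by side.
new_model = pybamm.load_model("spm_saved.json")
new_sim = pybamm.Simulation(new_model)
new_sim.solve([0, 3600])

pybamm.dynamic_plot([sim, new_sim], time_unit="seconds")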
@@ -76,15 +126,29 @@ def save_model(self, model, filename=None): "concatenated_initial_conditions": self._SymbolEncoder().default( model._concatenated_initial_conditions ), + "events": [self._SymbolEncoder().default(event) for event in model.events], + "mass_matrix": self._SymbolEncoder().default(model.mass_matrix), + "mass_matrix_inv": self._SymbolEncoder().default(model.mass_matrix_inv), } + if mesh: + model_json["mesh"] = self._MeshEncoder().default(mesh) + + if variables: + model_json["geometry"] = dict(model._geometry) + model_json["variables"] = { + k: self._SymbolEncoder().default(v) for k, v in dict(variables).items() + } + if filename is None: filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M_%S") with open(filename + ".json", "w") as f: json.dump(model_json, f) - def load_model(self, filename: str, battery_model: BaseBatteryModel = None): + def load_model( + self, filename: str, battery_model: BaseBatteryModel = None + ) -> BaseBatteryModel: """ Loads a discretised, ready to solve model into PyBaMM. @@ -106,6 +170,11 @@ def load_model(self, filename: str, battery_model: BaseBatteryModel = None): PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override any model names within the file. If None, the function will look for the saved object path, present if the original model came from PyBaMM. + + Returns + ------- + :class: pybamm.BaseBatteryModel + A PyBaMM model object, of type specified either in the JSON or in `battery_model`. """ with open(filename, "r") as f: @@ -124,8 +193,35 @@ def load_model(self, filename: str, battery_model: BaseBatteryModel = None): "concatenated_initial_conditions": self._reconstruct_epression_tree( model_data["concatenated_initial_conditions"] ), + "events": [ + self._reconstruct_epression_tree(event) + for event in model_data["events"] + ], + "mass_matrix": self._reconstruct_epression_tree(model_data["mass_matrix"]), + "mass_matrix_inv": self._reconstruct_epression_tree( + model_data["mass_matrix_inv"] + ), } + recon_model_dict["geometry"] = ( + model_data["geometry"] if "geometry" in model_data.keys() else None + ) + + recon_model_dict["mesh"] = ( + self._reconstruct_mesh(model_data["mesh"]) + if "mesh" in model_data.keys() + else None + ) + + recon_model_dict["variables"] = ( + { + k: self._reconstruct_epression_tree(v) + for k, v in model_data["variables"].items() + } + if "variables" in model_data.keys() + else None + ) + if battery_model: return battery_model.deserialise(recon_model_dict) @@ -141,7 +237,6 @@ def load_model(self, filename: str, battery_model: BaseBatteryModel = None): def _get_pybamm_class(self, snippet: dict): """Find a pybamm class to initialise from object path""" - empty_class = self._Empty() parts = snippet["py/object"].split(".") try: module = importlib.import_module(".".join(parts[:-1])) @@ -149,7 +244,13 @@ def _get_pybamm_class(self, snippet: dict): print(ex) class_ = getattr(module, parts[-1]) - empty_class.__class__ = class_ + + try: + empty_class = self._Empty() + empty_class.__class__ = class_ + except: + empty_class = self._EmptyDict() + empty_class.__class__ = class_ return empty_class @@ -176,7 +277,21 @@ def _reconstruct_epression_tree(self, node: dict): for i, c in enumerate(node["children"]): child_obj = self._reconstruct_epression_tree(c) node["children"][i] = child_obj + elif "expression" in node: + expression_obj = self._reconstruct_epression_tree(node["expression"]) + node["expression"] = expression_obj obj = self._reconstruct_symbol(node) return obj + + def 
_reconstruct_mesh(self, node: dict): + """Reconstructs a Mesh object""" + if "sub_meshes" in node: + for k, v in node["sub_meshes"].items(): + sub_mesh = self._reconstruct_symbol(v) + node["sub_meshes"][k] = sub_mesh + + new_mesh = self._reconstruct_symbol(node) + + return new_mesh diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index a828b8442b..b4db6b6528 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -34,6 +34,21 @@ def __init__(self, name, child, domains=None): super().__init__(name, children=[child], domains=domains) self.child = self.children[0] + @classmethod + def _from_json(cls, name, snippet: dict): + """Use to instantiate when deserialising""" + + instance = cls.__new__(cls) + + super(UnaryOperator, instance).__init__( + name, + snippet["children"], + domains=snippet["domains"], + ) + instance.child = instance.children[0] + + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{}({!s})".format(self.name, self.child) @@ -99,6 +114,12 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("-", child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = super()._from_json("-", snippet) + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{}{!s}".format(self.name, self.child) @@ -353,11 +374,6 @@ class with a :class:`Matrix` def __init__(self, name, child, domains=None): super().__init__(name, child, domains) - def diff(self, variable): - """See :meth:`pybamm.Symbol.diff()`.""" - # We shouldn't need this - raise NotImplementedError - def to_json(self): raise NotImplementedError( "pybamm.SpatialOperator: Serialisation is only implemented for discretised models." @@ -944,12 +960,34 @@ def __init__(self, children, initial_condition): super().__init__("explicit time integral", children) self.initial_condition = initial_condition + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__(snippet["children"][0], snippet["initial_condition"]) + + return instance + def _unary_new_copy(self, child): return self.__class__(child, self.initial_condition) def is_constant(self): return False + def to_json(self): + """ + Convert ExplicitTimeIntegral to json for serialisation. + + Both `children` and `initial_condition` contain Symbols, and are therefore + dealt with by `pybamm.Serialise._SymbolEncoder.default()` directly. 
+ """ + json_dict = { + "name": self.name, + "id": self.id, + } + + return json_dict + class BoundaryGradient(BoundaryOperator): """ diff --git a/pybamm/meshes/meshes.py b/pybamm/meshes/meshes.py index 4c86290a2f..182282319f 100644 --- a/pybamm/meshes/meshes.py +++ b/pybamm/meshes/meshes.py @@ -120,6 +120,21 @@ def __init__(self, geometry, submesh_types, var_pts): # add ghost meshes self.add_ghost_meshes() + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + super(Mesh, instance).__init__() + + instance.submesh_pts = snippet["submesh_pts"] + instance.base_domains = snippet["base_domains"] + + for k, v in snippet["sub_meshes"].items(): + instance[k] = v + + # instance.add_ghost_meshes() + + return instance + def __getitem__(self, domains): if isinstance(domains, str): domains = (domains,) @@ -216,6 +231,14 @@ def geometry(self): def geometry(self, geometry): self._geometry = geometry + def to_json(self): + json_dict = { + "submesh_pts": self.submesh_pts, + "base_domains": self.base_domains, + } + + return json_dict + class SubMesh: """ diff --git a/pybamm/meshes/one_dimensional_submeshes.py b/pybamm/meshes/one_dimensional_submeshes.py index 2beae6bc3a..147ed590cf 100644 --- a/pybamm/meshes/one_dimensional_submeshes.py +++ b/pybamm/meshes/one_dimensional_submeshes.py @@ -70,6 +70,17 @@ def read_lims(self, lims): return spatial_var, spatial_lims, tabs + def to_json(self): + json_dict = { + "edges": self.edges.tolist(), + "coord_sys": self.coord_sys, + } + + if hasattr(self, "tabs"): + json_dict["tabs"] = self.tabs + + return json_dict + class Uniform1DSubMesh(SubMesh1D): """ @@ -95,6 +106,18 @@ def __init__(self, lims, npts): super().__init__(edges, coord_sys=coord_sys, tabs=tabs) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + tabs = snippet["tabs"] if "tabs" in snippet.keys() else None + + super(Uniform1DSubMesh, instance).__init__( + np.array(snippet["edges"]), snippet["coord_sys"], tabs=tabs + ) + + return instance + class Exponential1DSubMesh(SubMesh1D): """ diff --git a/pybamm/meshes/zero_dimensional_submesh.py b/pybamm/meshes/zero_dimensional_submesh.py index 5b2f38e29f..dd4afe70fd 100644 --- a/pybamm/meshes/zero_dimensional_submesh.py +++ b/pybamm/meshes/zero_dimensional_submesh.py @@ -38,6 +38,23 @@ def __init__(self, position, npts=None): self.coord_sys = None self.npts = 1 + @classmethod + def _from_json(cls, snippet): + instance = cls.__new__(cls) + + instance.nodes = np.array(snippet["spatial_position"]) + instance.edges = np.array(snippet["spatial_position"]) + instance.coord_sys = None + instance.npts = 1 + + return instance + def add_ghost_meshes(self): # No ghost meshes to be added to this class pass + + def to_json(self): + json_dict = { + "spatial_position": self.nodes.tolist(), + } + return json_dict diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index f407a15c08..41192dbe1f 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -10,7 +10,6 @@ import pybamm from pybamm.expression_tree.operations.latexify import Latexify -from pybamm.expression_tree.operations.serialise import Serialise class BaseModel: @@ -124,35 +123,6 @@ def __init__(self, name="Unnamed model"): self.is_discretised = False self.y_slices = None - # PL: Next up, how to pass in the non-standard variables, if necessary. - @classmethod - def deserialise(cls, properties: dict): - """ - Create a model instance from a serialised object. 
- """ - instance = cls.__new__(cls) - - instance.__init__(name=properties["name"]) - - # Initialise model with stored variables - instance._options = properties["options"] # For information only - instance._concatenated_rhs = properties["concatenated_rhs"] - instance._concatenated_algebraic = properties["concatenated_algebraic"] - instance._concatenated_initial_conditions = properties[ - "concatenated_initial_conditions" - ] - - # Model has already been discretised - instance.is_discretised = True - - instance.len_rhs = instance.concatenated_rhs.size - instance.len_alg = instance.concatenated_algebraic.size - instance.len_rhs_and_alg = instance.len_rhs + instance.len_alg - - instance.bounds = properties["bounds"] - - return instance - @property def name(self): return self._name @@ -1140,34 +1110,6 @@ def process_parameters_and_discretise(self, symbol, parameter_values, disc): return disc_symbol - def save_model(self, filename=None): - """ - Write out a discretised model to a JSON file - - Parameters - ---------- - filename: str, optional - The desired name of the JSON file. If no name is provided, one will be created - based on the model name, and the current datetime. - """ - Serialise().save_model(self, filename=filename) - - -def load_model(filename, battery_model: BaseModel = None): - """ - Load in a saved model from a JSON file - - Parameters - ---------- - filename: str - Path to the JSON file containing the serialised model file - battery_model: :class: pybamm.BaseBatteryModel, optional - PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override - any model names within the file. If None, the function will look for the saved object - path, present if the original model came from PyBaMM. - """ - return Serialise().load_model(filename, battery_model) - # helper functions for finding symbols def find_symbol_in_tree(tree, name): diff --git a/pybamm/models/event.py b/pybamm/models/event.py index e93262641d..105106c470 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -46,6 +46,28 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._expression = expression self._event_type = event_type + @classmethod + def _from_json(cls, snippet: dict): + """ + Reconstructs an Event instance during deserialisation of a JSON file. + + Parameters + ---------- + snippet: dict + Contains the information needed to reconstruct a specific instance. + Should contain "name", "expression" and "event_type". + """ + + instance = cls.__new__(cls) + + instance.__init__( + snippet["name"], + snippet["expression"], + event_type=EventType(snippet["event_type"][1]), + ) + + return instance + def evaluate(self, t=None, y=None, y_dot=None, inputs=None): """ Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate` @@ -66,3 +88,19 @@ def expression(self): @property def event_type(self): return self._event_type + + def to_json(self): + """ + Method to serialise an Event object into JSON. + + The expression is written out seperately, + See :meth:`pybamm.Serialise._SymbolEncoder.default()` + """ + + # event_type contains string name, for JSON readability, and value for deserialisation. 
+ json_dict = { + "name": self._name, + "event_type": [str(self._event_type), self._event_type.value], + } + + return json_dict diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index ad36786381..841bf53a81 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -17,6 +17,9 @@ def represents_positive_integer(s): return val > 0 +from pybamm.expression_tree.operations.serialise import Serialise + + class BatteryModelOptions(pybamm.FuzzyDict): """ Attributes @@ -799,6 +802,66 @@ def __init__(self, options=None, name="Unnamed battery model"): super().__init__(name) self.options = options + # PL: Next up, how to pass in the non-standard variables, if necessary. + @classmethod + def deserialise( + cls, properties: dict + ): # PL: maybe option up here as output_mesh=true to output a tuple, (model, mesh) rather than just updating the variables and leaving it at that. + """ + Create a model instance from a serialised object. + """ + instance = cls.__new__(cls) + + # append the model name with _saved to differentiate + instance.__init__( + options=properties["options"], name=properties["name"] + "_saved" + ) + + # Initialise model with stored variables that have already been discretised + instance._concatenated_rhs = properties["concatenated_rhs"] + instance._concatenated_algebraic = properties["concatenated_algebraic"] + instance._concatenated_initial_conditions = properties[ + "concatenated_initial_conditions" + ] + + instance.len_rhs = instance.concatenated_rhs.size + instance.len_alg = instance.concatenated_algebraic.size + instance.len_rhs_and_alg = instance.len_rhs + instance.len_alg + + instance.bounds = properties["bounds"] + instance.events = properties["events"] + instance.mass_matrix = properties["mass_matrix"] + instance.mass_matrix_inv = properties["mass_matrix_inv"] + + # add optional properties not required for model to solve + if properties["variables"]: + instance._variables = pybamm.FuzzyDict(properties["variables"]) + + # assign meshes to each variable + for var in instance._variables.values(): + if var.domain != []: + var.mesh = properties["mesh"][var.domain] + else: + var.mesh = None + + if var.domains["secondary"] != []: + var.secondary_mesh = properties["mesh"][var.domains["secondary"]] + else: + var.secondary_mesh = None + + instance._geometry = pybamm.Geometry(properties["geometry"]) + else: + # Delete the default variables which have not been discretised + instance._variables = pybamm.FuzzyDict({}) + + # PL: Simulation(new_model, new_mesh) + # doesn't work because the model is already discretised, you can't give it a new mesh. + + # Model has already been discretised + instance.is_discretised = True + + return instance + @property def default_geometry(self): return pybamm.battery_geometry(options=self.options) @@ -1379,3 +1442,36 @@ def set_soc_variables(self): This function is overriden by the base battery models """ pass + + def save_model(self, filename=None, mesh=None, variables=None): + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. 
+ """ + if variables and not mesh: + raise ValueError( + "Serialisation: Please provide the mesh if variables are required" + ) + + Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) + + +def load_model(filename, battery_model: BaseBatteryModel = None): + """ + Load in a saved model from a JSON file + + Parameters + ---------- + filename: str + Path to the JSON file containing the serialised model file + battery_model: :class: pybamm.BaseBatteryModel, optional + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override + any model names within the file. If None, the function will look for the saved object + path, present if the original model came from PyBaMM. + """ + return Serialise().load_model(filename, battery_model) diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 03bfeeccd4..3f55648225 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -154,6 +154,10 @@ def __init__( f"No default output variables provided for {models[0].name}" ) + # check variables have been provided after any serialisation + if any(len(m.variables) == 0 for m in models): + raise AttributeError(f"Variables not provided by the serialised model") + self.n_rows = n_rows or int( len(output_variables) // np.sqrt(len(output_variables)) ) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index da2bac841b..4a71e819bd 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -1188,18 +1188,35 @@ def save(self, filename): with open(filename, "wb") as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) - def save_model(self, filename: str = None): + def save_model( + self, + filename: str = None, + mesh: bool = False, + variables: bool = False, + ): """ Write out a discretised model to a JSON file Parameters ---------- + mesh: bool + The mesh used to discretise the model. If false, plotting tools will not + be available when the model is read back in and solved. + variables: bool + The discretised variables. Not required to solve a model, but if false + tools will not be availble. Will automatically save meshes as well, required + for plotting tools. filename: str, optional - The desired name of the JSON file. If no name is provided, one will be created - based on the model name, and the current datetime. + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. 
""" + mesh = self.mesh if (mesh or variables) else None + variables = self.built_model.variables if variables else None + if self.built_model: - Serialise().save_model(self.built_model, filename=filename) + Serialise().save_model( + self.built_model, filename=filename, mesh=mesh, variables=variables + ) else: raise NotImplementedError( """ From 7fadee320094b57eb8c1387cd90746b06a1175ed Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 19 Sep 2023 16:45:54 +0000 Subject: [PATCH 05/29] Add unit tests for to_json() --- pybamm/expression_tree/array.py | 3 +- pybamm/expression_tree/interpolant.py | 4 +- pybamm/expression_tree/state_vector.py | 1 - pybamm/expression_tree/variable.py | 4 +- tests/unit/test_expression_tree/test_array.py | 18 ++++++ .../test_binary_operators.py | 59 +++++++++++++++++++ .../test_expression_tree/test_broadcasts.py | 6 ++ .../test_concatenations.py | 59 +++++++++++++++++++ .../test_input_parameter.py | 14 +++++ .../test_expression_tree/test_interpolant.py | 8 +++ .../unit/test_expression_tree/test_matrix.py | 24 ++++++++ .../test_expression_tree/test_parameter.py | 12 ++++ .../unit/test_expression_tree/test_scalar.py | 6 ++ .../test_expression_tree/test_state_vector.py | 26 ++++++++ .../unit/test_expression_tree/test_symbol.py | 20 +++++++ .../test_unary_operators.py | 52 ++++++++++++++++ .../test_expression_tree/test_variable.py | 5 ++ 17 files changed, 316 insertions(+), 5 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 270c546dbe..0fc74f3209 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -3,7 +3,7 @@ # import numpy as np import sympy -from scipy.sparse import csr_matrix, issparse, csr_array +from scipy.sparse import csr_matrix, issparse import pybamm @@ -176,7 +176,6 @@ def to_json(self): "id": self.id, "domains": self.domains, "entries": matrix, - # "entries_string": self.entries_string.decode(), } return json_dict diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 16bbe88d7e..5234e5e927 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -292,4 +292,6 @@ def _function_evaluate(self, evaluated_children): raise ValueError("Invalid dimension: {0}".format(self.dimension)) def to_json(self): - raise NotImplementedError + raise NotImplementedError( + "pybamm.Interpolant: Serialisation is only implemented for discretised models." + ) diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 9a414dc049..72b1ed18a5 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -227,7 +227,6 @@ def to_json(self): for y in self.y_slices ], "evaluation_array": list(self.evaluation_array), - # "children": self.children, # might not need this, the anytree exporter handles children I think } return json_dict diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 8aa2b1d707..8fe655d513 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -132,7 +132,9 @@ def to_equation(self): def to_json( self, ): - raise NotImplementedError + raise NotImplementedError( + "pybamm.Variable: Serialisation is only implemented for discretised models." 
+ ) class Variable(VariableBase): diff --git a/tests/unit/test_expression_tree/test_array.py b/tests/unit/test_expression_tree/test_array.py index da79dbb6e0..6ef1669270 100644 --- a/tests/unit/test_expression_tree/test_array.py +++ b/tests/unit/test_expression_tree/test_array.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import numpy as np import sympy @@ -41,6 +42,23 @@ def test_to_equation(self): pybamm.Array([1, 2]).to_equation(), sympy.Array([[1.0], [2.0]]) ) + def test_to_json_array(self): + arr = pybamm.Array(np.array([1, 2, 3])) + self.assertEqual( + arr.to_json(), + { + "name": "Array of shape (3, 1)", + "id": mock.ANY, # The value of the ID will change, but want to check it is present + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "entries": [[1.0], [2.0], [3.0]], + }, + ) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 6acd7c41b0..9a66e3a639 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import numpy as np import sympy @@ -10,6 +11,13 @@ import pybamm +EMPTY_DOMAINS = { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], +} + class TestBinaryOperators(TestCase): def test_binary_operator(self): @@ -770,6 +778,57 @@ def test_to_equation(self): # Test NotEqualHeaviside self.assertEqual(pybamm.NotEqualHeaviside(2, 4).to_equation(), True) + def test_to_json(self): + # Test Addition + self.assertEqual( + pybamm.Addition(2, 4).to_json(), + { + "name": "+", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + }, + ) + + # Test Power + self.assertEqual( + pybamm.Power(7, 2).to_json(), + { + "name": "**", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + }, + ) + + # Test Division + self.assertEqual( + pybamm.Division(10, 5).to_json(), + { + "name": "/", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + }, + ) + + # Test EqualHeaviside + self.assertEqual( + pybamm.EqualHeaviside(2, 4).to_json(), + { + "name": "<=", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + }, + ) + + # Test notEqualHeaviside + self.assertEqual( + pybamm.NotEqualHeaviside(2, 4).to_json(), + { + "name": "<", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + }, + ) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_broadcasts.py b/tests/unit/test_expression_tree/test_broadcasts.py index 81d1210229..be6772af2d 100644 --- a/tests/unit/test_expression_tree/test_broadcasts.py +++ b/tests/unit/test_expression_tree/test_broadcasts.py @@ -350,6 +350,12 @@ def test_diff(self): self.assertIsInstance(d, pybamm.Scalar) self.assertEqual(d.evaluate(y=y), 0) + def test_to_json(self): + a = pybamm.StateVector(slice(0, 1)) + b = pybamm.PrimaryBroadcast(a, "separator") + with self.assertRaises(NotImplementedError): + b.to_json() + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index df5add0f98..f846220a77 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -2,6 +2,7 @@ # Tests for the Concatenation class and subclasses # 
import unittest +import unittest.mock as mock from tests import TestCase import numpy as np @@ -382,6 +383,64 @@ def test_to_equation(self): # Test concat_sym self.assertEqual(pybamm.Concatenation(a, b).to_equation(), func_symbol) + def test_to_json(self): + # test DomainConcatenation + mesh = get_mesh_for_testing() + a = pybamm.Symbol("a", domain=["negative electrode"]) + b = pybamm.Symbol("b", domain=["separator", "positive electrode"]) + conc = pybamm.DomainConcatenation([a, b], mesh) + + json_dict = { + "name": "domain_concatenation", + "id": mock.ANY, + "domains": { + "primary": ["negative electrode", "separator", "positive electrode"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "slices": { + "negative electrode": [{"start": 0, "stop": 40, "step": None}], + "separator": [{"start": 40, "stop": 65, "step": None}], + "positive electrode": [{"start": 65, "stop": 100, "step": None}], + }, + "size": 100, + "children_slices": [ + {"negative electrode": [{"start": 0, "stop": 40, "step": None}]}, + { + "separator": [{"start": 0, "stop": 25, "step": None}], + "positive electrode": [{"start": 25, "stop": 60, "step": None}], + }, + ], + "secondary_dimensions_npts": 1, + } + + self.assertEqual( + conc.to_json(), + json_dict, + ) + + # test NumpyConcatenation + y = np.linspace(0, 1, 15)[:, np.newaxis] + a_np = pybamm.Vector(y[:5]) + b_np = pybamm.Vector(y[5:9]) + c_np = pybamm.Vector(y[9:]) + conc_np = pybamm.NumpyConcatenation(a_np, b_np, c_np) + + self.assertEqual( + conc_np.to_json(), + { + "name": "numpy_concatenation", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + }, + ) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_input_parameter.py b/tests/unit/test_expression_tree/test_input_parameter.py index 82dd06fee5..48ad2c441f 100644 --- a/tests/unit/test_expression_tree/test_input_parameter.py +++ b/tests/unit/test_expression_tree/test_input_parameter.py @@ -6,6 +6,8 @@ import pybamm import unittest +import unittest.mock as mock + class TestInputParameter(TestCase): def test_input_parameter_init(self): @@ -49,6 +51,18 @@ def test_errors(self): with self.assertRaises(KeyError): a.evaluate() + def test_to_json(self): + a = pybamm.InputParameter("a") + + json_dict = { + "name": "a", + "id": mock.ANY, + "domain": [], + "expected_size": 1, + } + + self.assertEqual(a.to_json(), json_dict) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index e1547ef3fc..b6c195eccc 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -325,6 +325,14 @@ def test_processing(self): self.assertEqual(interp, interp.new_copy()) + def test_to_json(self): + x = np.linspace(0, 1, 200) + y = pybamm.StateVector(slice(0, 2)) + interp = pybamm.Interpolant(x, 2 * x, y) + + with self.assertRaises(NotImplementedError): + interp.to_json() + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 39aba44483..8e466818f1 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -4,8 +4,10 @@ from tests import TestCase import pybamm import numpy as np +from scipy.sparse import csr_matrix import 
unittest +import unittest.mock as mock class TestMatrix(TestCase): @@ -38,6 +40,28 @@ def test_matrix_operations(self): (self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]]) ) + def test_to_json_matrix(self): + arr = pybamm.Matrix(csr_matrix([[0, 1, 0, 0], [0, 0, 0, 1]])) + self.assertEqual( + arr.to_json(), + { + "name": "Sparse Matrix (2, 4)", + "id": mock.ANY, # The value of the ID will change, but want to check it is present + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "entries": { + "column_pointers": [0, 1, 2], + "data": [1.0, 1.0], + "row_indices": [1, 3], + "shape": (2, 4), + }, + }, + ) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index f67ee2dd62..6001d8906b 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -31,6 +31,12 @@ def test_to_equation(self): # Test name self.assertEqual(func1.to_equation(), sympy.Symbol("test_name")) + def test_to_json(self): + func = pybamm.Parameter("test_string") + + with self.assertRaises(NotImplementedError): + func.to_json() + class TestFunctionParameter(TestCase): def test_function_parameter_init(self): @@ -109,6 +115,12 @@ def test_function_parameter_to_equation(self): func1.print_name = None self.assertEqual(func1.to_equation(), sympy.Symbol("func")) + def test_to_json(self): + func = pybamm.FunctionParameter("test", {"x": pybamm.Scalar(1)}) + + with self.assertRaises(NotImplementedError): + func.to_json() + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_scalar.py b/tests/unit/test_expression_tree/test_scalar.py index af0a6e80ca..9d990e354d 100644 --- a/tests/unit/test_expression_tree/test_scalar.py +++ b/tests/unit/test_expression_tree/test_scalar.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import pybamm @@ -44,6 +45,11 @@ def test_copy(self): b = a.create_copy() self.assertEqual(a, b) + def test_to_json(self): + a = pybamm.Scalar(5) + + self.assertEqual(a.to_json(), {"name": "5.0", "id": mock.ANY, "value": 5.0}) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_state_vector.py b/tests/unit/test_expression_tree/test_state_vector.py index d401487264..0165d1d512 100644 --- a/tests/unit/test_expression_tree/test_state_vector.py +++ b/tests/unit/test_expression_tree/test_state_vector.py @@ -6,6 +6,7 @@ import numpy as np import unittest +import unittest.mock as mock class TestStateVector(TestCase): @@ -62,6 +63,31 @@ def test_failure(self): with self.assertRaisesRegex(TypeError, "all y_slices must be slice objects"): pybamm.StateVector(slice(0, 10), 1) + def test_to_json(self): + array = np.array([1, 2, 3, 4, 5]) + sv = pybamm.StateVector(slice(0, 10), evaluation_array=array) + + json_dict = { + "name": "y[0:10]", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "y_slice": [ + { + "start": 0, + "stop": 10, + "step": None, + } + ], + "evaluation_array": [1, 2, 3, 4, 5], + } + + self.assertEqual(sv.to_json(), json_dict) + class TestStateVectorDot(TestCase): def test_evaluate(self): diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 3a74375ce7..17c5f0a02f 100644 --- 
a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -4,6 +4,7 @@ from tests import TestCase import os import unittest +import unittest.mock as mock import numpy as np from scipy.sparse import csr_matrix, coo_matrix @@ -486,6 +487,25 @@ def test_numpy_array_ufunc(self): x = pybamm.Symbol("x") self.assertEqual(np.exp(x), pybamm.exp(x)) + def test_to_json(self): + symc1 = pybamm.Symbol("child1", domain=["domain_1"]) + symc2 = pybamm.Symbol("child2", domain=["domain_2"]) + symp = pybamm.Symbol("parent", domain=["domain_3"], children=[symc1, symc2]) + + self.assertEqual( + symp.to_json(), + { + "name": "parent", + "id": mock.ANY, + "domains": { + "primary": ["domain_3"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + }, + ) + class TestIsZero(TestCase): def test_is_scalar_zero(self): diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index b0513c974b..e8fc7c7be0 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -3,6 +3,7 @@ # import unittest from tests import TestCase +import unittest.mock as mock import numpy as np import sympy @@ -668,6 +669,57 @@ def test_explicit_time_integral(self): self.assertEqual(expr.new_copy(), expr) self.assertFalse(expr.is_constant()) + def test_to_json(self): + # UnaryOperator + a = pybamm.Symbol("a", domain=["test"]) + un = pybamm.UnaryOperator("unary test", a) + self.assertEqual( + un.to_json(), + { + "name": "unary test", + "id": mock.ANY, + "domains": { + "primary": ["test"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + }, + ) + + # Index + vec = pybamm.StateVector(slice(0, 5)) + ind = pybamm.Index(vec, 3) + self.assertEqual( + ind.to_json(), + { + "name": "Index[3]", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "check_size": False, + }, + ) + + # SpatialOperator + spatial_vec = pybamm.SpatialOperator("name", vec) + with self.assertRaises(NotImplementedError): + spatial_vec.to_json() + + # ExplicitTimeIntegral + expr = pybamm.ExplicitTimeIntegral(pybamm.Parameter("param"), pybamm.Scalar(1)) + self.assertEqual( + expr.to_json(), + { + "name": "explicit time integral", + "id": mock.ANY, + }, + ) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_variable.py b/tests/unit/test_expression_tree/test_variable.py index be791903e2..f4f3029c75 100644 --- a/tests/unit/test_expression_tree/test_variable.py +++ b/tests/unit/test_expression_tree/test_variable.py @@ -63,6 +63,11 @@ def test_to_equation(self): # Test name self.assertEqual(pybamm.Variable("name").to_equation(), sympy.Symbol("name")) + def test_to_json(self): + func = pybamm.Variable("test_string") + with self.assertRaises(NotImplementedError): + func.to_json() + class TestVariableDot(TestCase): def test_variable_init(self): From 25cb002789ab71212653d0880b20b60d0a7f6618 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 21 Sep 2023 16:01:30 +0000 Subject: [PATCH 06/29] Allow saving of geometry where symbols are dict keys Put warning in for BaseModel - atm requires more model information to re-create the model. 
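
The geometry handling below, in minimal standalone form: pybamm symbols used as
dictionary keys are flattened into JSON-friendly string keys before saving. The
sketch is illustrative only - the patch does the real work in
Serialise._deconstruct_pybamm_dicts() with its symbol encoder, and the name-only
stand-in below replaces the fully encoded symbol.

    import json

    import pybamm

    # a geometry-like dict: the inner key is a pybamm symbol, which json.dumps rejects
    x = pybamm.SpatialVariable("x", domain="rod", coord_sys="cartesian")
    geometry = {"rod": {x: {"min": 0.0, "max": 2.0}}}

    def flatten(obj):
        # replace a symbol key with a plain "<name>" key, plus a "symbol_<name>"
        # entry holding the information needed to rebuild the symbol
        if not isinstance(obj, dict):
            return obj
        out = {}
        for k, v in obj.items():
            if isinstance(k, pybamm.Symbol):
                out["symbol_" + k.name] = {"name": k.name, "domains": k.domains}
                k = k.name
            out[k] = flatten(v)
        return out

    print(json.dumps(flatten(geometry), indent=2))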
--- .../notebooks/models/saving_models.ipynb | 25 ++-- .../expression_tree/operations/serialise.py | 118 ++++++++++++++++-- pybamm/models/base_model.py | 11 ++ 3 files changed, 129 insertions(+), 25 deletions(-) diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb index 94799bcc48..c3c9c90ea8 100644 --- a/docs/source/examples/notebooks/models/saving_models.ipynb +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Models which are discretised (i.e. ready to solve/ previously solved, see A DIFFERENT NOTEBOOK) can be serialised and saved to a JSON file, ready to be read in again either in PyBaMM, or a different modelling library. \n", + "Models which are discretised (i.e. ready to solve/ previously solved, see [this notebook](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/spatial_methods/finite-volumes.ipynb) for more information on the pybamm.Discretisation class) can be serialised and saved to a JSON file, ready to be read in again either in PyBaMM, or a different modelling library. \n", "\n", "In the example below, we build and solve a basic DFN model, and then save the model out to `sim_model_example.json`, which should have appear in the 'models' directory." ] @@ -25,9 +25,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } @@ -59,7 +56,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -98,7 +95,7 @@ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/home/pliggins/PyBaMM/docs/source/examples/notebooks/models/saving_models.ipynb Cell 7\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m plot_sim\u001b[39m.\u001b[39msolve([\u001b[39m0\u001b[39m, \u001b[39m3600\u001b[39m])\n\u001b[1;32m 6\u001b[0m sims\u001b[39m.\u001b[39mappend(plot_sim)\n\u001b[0;32m----> 8\u001b[0m pybamm\u001b[39m.\u001b[39;49mdynamic_plot(sims, time_unit\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mseconds\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", "File \u001b[0;32m~/PyBaMM/pybamm/plotting/dynamic_plot.py:20\u001b[0m, in \u001b[0;36mdynamic_plot\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[39mCreates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[39marguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39m The 'QuickPlot' object that was created\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 19\u001b[0m kwargs_for_class \u001b[39m=\u001b[39m {k: v \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m kwargs\u001b[39m.\u001b[39mitems() \u001b[39mif\u001b[39;00m k \u001b[39m!=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[0;32m---> 20\u001b[0m 
plot \u001b[39m=\u001b[39m pybamm\u001b[39m.\u001b[39;49mQuickPlot(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs_for_class)\n\u001b[1;32m 21\u001b[0m plot\u001b[39m.\u001b[39mdynamic_plot(kwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m))\n\u001b[1;32m 22\u001b[0m \u001b[39mreturn\u001b[39;00m plot\n", - "File \u001b[0;32m~/PyBaMM/pybamm/plotting/quick_plot.py:163\u001b[0m, in \u001b[0;36mQuickPlot.__init__\u001b[0;34m(self, solutions, output_variables, labels, colors, linestyles, shading, figsize, n_rows, time_unit, spatial_unit, variable_limits)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[39m# check variables have been provided after any serialisation\u001b[39;00m\n\u001b[1;32m 162\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(\u001b[39mlen\u001b[39m(m\u001b[39m.\u001b[39mvariables) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m \u001b[39mfor\u001b[39;00m m \u001b[39min\u001b[39;00m models):\n\u001b[0;32m--> 163\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mVariables not provided by the serialised model\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 165\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows \u001b[39m=\u001b[39m n_rows \u001b[39mor\u001b[39;00m \u001b[39mint\u001b[39m(\n\u001b[1;32m 166\u001b[0m \u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m np\u001b[39m.\u001b[39msqrt(\u001b[39mlen\u001b[39m(output_variables))\n\u001b[1;32m 167\u001b[0m )\n\u001b[1;32m 168\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_cols \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(np\u001b[39m.\u001b[39mceil(\u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows))\n", + "File \u001b[0;32m~/PyBaMM/pybamm/plotting/quick_plot.py:159\u001b[0m, in \u001b[0;36mQuickPlot.__init__\u001b[0;34m(self, solutions, output_variables, labels, colors, linestyles, shading, figsize, n_rows, time_unit, spatial_unit, variable_limits)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[39m# check variables have been provided after any serialisation\u001b[39;00m\n\u001b[1;32m 158\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(\u001b[39mlen\u001b[39m(m\u001b[39m.\u001b[39mvariables) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m \u001b[39mfor\u001b[39;00m m \u001b[39min\u001b[39;00m models):\n\u001b[0;32m--> 159\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mVariables not provided by the serialised model\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 161\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows \u001b[39m=\u001b[39m n_rows \u001b[39mor\u001b[39;00m \u001b[39mint\u001b[39m(\n\u001b[1;32m 162\u001b[0m \u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m np\u001b[39m.\u001b[39msqrt(\u001b[39mlen\u001b[39m(output_variables))\n\u001b[1;32m 163\u001b[0m )\n\u001b[1;32m 164\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_cols \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(np\u001b[39m.\u001b[39mceil(\u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows))\n", "\u001b[0;31mAttributeError\u001b[0m: Variables not provided by the serialised model" ] } @@ -131,7 +128,7 @@ { "data": { 
"application/vnd.jupyter.widget-view+json": { - "model_id": "b6b4db83fd054ba4be3ee279f7024c6a", + "model_id": "eaf8ae8b8dd84a99b8b1aecfc132ad83", "version_major": 2, "version_minor": 0 }, @@ -145,7 +142,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -179,7 +176,9 @@ "\n", "Alternatively, the model can be saved directly from the Model class.\n", "\n", - "First set up the model, as explained in detail in the SPM NOTEBOOK" + "Note that at the moment, only models derived from the BaseBatteryModel class can be serialised; those built from scratch using pybamm.BaseModel() are currently unsupported.\n", + "\n", + "First set up the model, as explained in detail for the [SPM](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/models/SPM.ipynb)." ] }, { @@ -190,7 +189,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -244,7 +243,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b6df594b3af646599430ff322349b44f", + "model_id": "a1c0b22c969b45858361b7e9de264e76", "version_major": 2, "version_minor": 0 }, @@ -258,7 +257,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -289,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index e11049a35a..d27f770451 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -5,6 +5,7 @@ import json import importlib import numpy as np +import re from typing import TYPE_CHECKING @@ -135,7 +136,9 @@ def save_model( model_json["mesh"] = self._MeshEncoder().default(mesh) if variables: - model_json["geometry"] = dict(model._geometry) + model_json["geometry"] = self._deconstruct_pybamm_dicts( + dict(model._geometry) + ) model_json["variables"] = { k: self._SymbolEncoder().default(v) for k, v in dict(variables).items() } @@ -184,27 +187,29 @@ def load_model( "name": model_data["name"], "options": model_data["options"], "bounds": tuple(np.array(bound) for bound in model_data["bounds"]), - "concatenated_rhs": self._reconstruct_epression_tree( + "concatenated_rhs": self._reconstruct_expression_tree( model_data["concatenated_rhs"] ), - "concatenated_algebraic": self._reconstruct_epression_tree( + "concatenated_algebraic": self._reconstruct_expression_tree( model_data["concatenated_algebraic"] ), - "concatenated_initial_conditions": self._reconstruct_epression_tree( + "concatenated_initial_conditions": self._reconstruct_expression_tree( model_data["concatenated_initial_conditions"] ), "events": [ - self._reconstruct_epression_tree(event) + self._reconstruct_expression_tree(event) for event in model_data["events"] ], - "mass_matrix": self._reconstruct_epression_tree(model_data["mass_matrix"]), - "mass_matrix_inv": self._reconstruct_epression_tree( + "mass_matrix": self._reconstruct_expression_tree(model_data["mass_matrix"]), + "mass_matrix_inv": self._reconstruct_expression_tree( model_data["mass_matrix_inv"] ), } recon_model_dict["geometry"] = ( - model_data["geometry"] if "geometry" in model_data.keys() else None + self._reconstruct_geometry(model_data["geometry"]) + if "geometry" in model_data.keys() + else None ) recon_model_dict["mesh"] = ( @@ -215,7 +220,7 @@ def load_model( recon_model_dict["variables"] = ( { - k: self._reconstruct_epression_tree(v) + k: self._reconstruct_expression_tree(v) for k, v in 
model_data["variables"].items()
             }
             if "variables" in model_data.keys()
@@ -235,6 +240,8 @@ def load_model(
             """
             )
 
+        # Helper functions
+
     def _get_pybamm_class(self, snippet: dict):
         """Find a pybamm class to initialise from object path"""
         parts = snippet["py/object"].split(".")
@@ -254,13 +261,55 @@ def _get_pybamm_class(self, snippet: dict):
 
         return empty_class
 
+    def _deconstruct_pybamm_dicts(self, dct: dict):
+        """
+        Converts dictionaries which contain pybamm classes as keys
+        into a JSON-serialisable format.
+
+        A dictionary key which is a pybamm object is replaced by the symbol's name,
+        and a separate "symbol_<name>" entry is added alongside it, storing the
+        dictionary required to reconstruct that symbol, e.g.:
+
+        {'rod':
+            {SpatialVariable(name='spat_var'): {"min":0.0, "max":2.0} }
+        }
+
+        converts to
+
+        {'rod':
+            {'symbol_spat_var': {"py/object": "pybamm....", ...},
+             'spat_var':
+                {"min":0.0, "max":2.0} }
+        }
+
+        Dictionaries which don't contain pybamm symbols are returned unchanged.
+        """
+
+        def nested_convert(obj):
+            if isinstance(obj, dict):
+                new_dict = {}
+                for k, v in obj.items():
+                    if isinstance(k, pybamm.Symbol):
+                        new_k = self._SymbolEncoder().default(k)
+                        new_dict["symbol_" + new_k["name"]] = new_k
+                        k = new_k["name"]
+                    new_dict[k] = nested_convert(v)
+                return new_dict
+            return obj
+
+        try:
+            _ = json.dumps(dct)
+            return dict(dct)
+        except TypeError:  # dct must contain pybamm objects
+            return nested_convert(dct)
+
     def _reconstruct_symbol(self, dct: dict):
         """Reconstruct an individual pybamm Symbol"""
         symbol_class = self._get_pybamm_class(dct)
         symbol = symbol_class._from_json(dct)
         return symbol
 
-    def _reconstruct_epression_tree(self, node: dict):
+    def _reconstruct_expression_tree(self, node: dict):
         """
         Loop through an expression tree creating pybamm Symbol classes
 
@@ -275,10 +324,10 @@ def _reconstruct_epression_tree(self, node: dict):
         """
         if "children" in node:
             for i, c in enumerate(node["children"]):
-                child_obj = self._reconstruct_epression_tree(c)
+                child_obj = self._reconstruct_expression_tree(c)
                 node["children"][i] = child_obj
         elif "expression" in node:
-            expression_obj = self._reconstruct_epression_tree(node["expression"])
+            expression_obj = self._reconstruct_expression_tree(node["expression"])
             node["expression"] = expression_obj
 
         obj = self._reconstruct_symbol(node)
@@ -295,3 +344,48 @@ def _reconstruct_mesh(self, node: dict):
         new_mesh = self._reconstruct_symbol(node)
 
         return new_mesh
+
+    def _reconstruct_geometry(self, obj: dict):
+        """
+        pybamm.Geometry can contain PyBaMM symbols as dictionary keys.
+
+        Converts
+        {"rod":
+            {"symbol_spat_var":
+                {"py/object": "pybamm....", ...},
+             "spat_var":
+                {"min":0.0, "max":2.0} }
+        }
+
+        from an exported JSON file to
+
+        {"rod":
+            {SpatialVariable(name="spat_var"): {"min":0.0, "max":2.0} }
+        }
+        """
+
+        def recurse(obj):
+            if isinstance(obj, dict):
+                new_dict = {}
+                for k, v in obj.items():
+                    if "symbol_" in k:
+                        new_dict[k] = self._reconstruct_symbol(v)
+                    elif isinstance(v, dict):
+                        new_dict[k] = recurse(v)
+                    else:
+                        new_dict[k] = v
+
+                pattern = re.compile("symbol_")
+                symbol_keys = {k: v for k, v in new_dict.items() if pattern.match(k)}
+
+                # rearrange the dictionary to make pybamm objects the dictionary keys
+                if symbol_keys:
+                    for k, v in symbol_keys.items():
+                        new_dict[v] = new_dict[k[len("symbol_") :]]
+                        del new_dict[k]
+                        del new_dict[k[len("symbol_") :]]
+
+                return new_dict
+            return obj
+
+        return recurse(obj)
diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py
index 41192dbe1f..80dda2808d 100644
--- a/pybamm/models/base_model.py
+++ b/pybamm/models/base_model.py
@@ -123,6 +123,12 @@ def __init__(self, name="Unnamed model"):
         self.is_discretised = False
         self.y_slices = None
 
+    @classmethod
+    def deserialise(cls, properties: dict):
+        raise NotImplementedError(
+            "BaseModel: Serialisation not yet implemented for non-battery models."
+        )
+
     @property
     def name(self):
         return self._name
@@ -1110,6 +1116,11 @@ def process_parameters_and_discretise(self, symbol, parameter_values, disc):
 
         return disc_symbol
 
+    def save_model(self, filename=None, mesh=None, variables=None):
+        raise NotImplementedError(
+            "BaseModel: Serialisation not yet implemented for non-battery models."
+        )
+
 
 # helper functions for finding symbols
 
 
 def find_symbol_in_tree(tree, name):
From efa78887a584f2733969a6a3a1b28e327897fb8d Mon Sep 17 00:00:00 2001
From: Pip Liggins 
Date: Thu, 21 Sep 2023 16:36:26 +0000
Subject: [PATCH 07/29] Add _from_json tests for symbols without children

---
 tests/unit/test_expression_tree/test_array.py | 33 +++++++++------
 .../test_expression_tree/test_broadcasts.py   |  2 +-
 .../test_input_parameter.py                   |  6 ++-
 .../test_expression_tree/test_interpolant.py  |  2 +-
 .../unit/test_expression_tree/test_matrix.py  | 39 ++++++++++---------
 .../test_expression_tree/test_parameter.py    |  4 +-
 .../unit/test_expression_tree/test_scalar.py  |  7 +++-
 .../test_expression_tree/test_state_vector.py |  4 +-
 .../test_expression_tree/test_variable.py     |  2 +-
 9 files changed, 57 insertions(+), 42 deletions(-)

diff --git a/tests/unit/test_expression_tree/test_array.py b/tests/unit/test_expression_tree/test_array.py
index 6ef1669270..885c5e0851 100644
--- a/tests/unit/test_expression_tree/test_array.py
+++ b/tests/unit/test_expression_tree/test_array.py
@@ -42,22 +42,27 @@ def test_to_equation(self):
             pybamm.Array([1, 2]).to_equation(), sympy.Array([[1.0], [2.0]])
         )
 
-    def test_to_json_array(self):
+    def test_to_from_json(self):
         arr = pybamm.Array(np.array([1, 2, 3]))
-        self.assertEqual(
-            arr.to_json(),
-            {
-                "name": "Array of shape (3, 1)",
-                "id": mock.ANY,  # The value of the ID will change, but want to check it is present
-                "domains": {
-                    "primary": [],
-                    "secondary": [],
-                    "tertiary": [],
-                    "quaternary": [],
-                },
-                "entries": [[1.0], [2.0], [3.0]],
+
+        json_dict = {
+            "name": "Array of shape (3, 1)",
+            "id": mock.ANY,  # The value of the ID will change, but want to check it is present
+            "domains": {
+                "primary": [],
+                "secondary": [],
+                "tertiary": [],
+                "quaternary": [],
             },
-        )
+            "entries": [[1.0], [2.0], [3.0]],
+        }
+
+        # array to json conversion
+        created_json = 
arr.to_json() + self.assertEqual(created_json, json_dict) + + # json to array conversion + self.assertEqual(pybamm.Array._from_json(created_json), arr) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_broadcasts.py b/tests/unit/test_expression_tree/test_broadcasts.py index be6772af2d..b91cd7d95c 100644 --- a/tests/unit/test_expression_tree/test_broadcasts.py +++ b/tests/unit/test_expression_tree/test_broadcasts.py @@ -350,7 +350,7 @@ def test_diff(self): self.assertIsInstance(d, pybamm.Scalar) self.assertEqual(d.evaluate(y=y), 0) - def test_to_json(self): + def test_to_json_error(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.PrimaryBroadcast(a, "separator") with self.assertRaises(NotImplementedError): diff --git a/tests/unit/test_expression_tree/test_input_parameter.py b/tests/unit/test_expression_tree/test_input_parameter.py index 48ad2c441f..a5fc79f2e2 100644 --- a/tests/unit/test_expression_tree/test_input_parameter.py +++ b/tests/unit/test_expression_tree/test_input_parameter.py @@ -51,7 +51,7 @@ def test_errors(self): with self.assertRaises(KeyError): a.evaluate() - def test_to_json(self): + def test_to_from_json(self): a = pybamm.InputParameter("a") json_dict = { @@ -61,8 +61,12 @@ def test_to_json(self): "expected_size": 1, } + # to_json self.assertEqual(a.to_json(), json_dict) + # from_json + self.assertEqual(pybamm.InputParameter._from_json(json_dict), a) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index b6c195eccc..7389ff183a 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -325,7 +325,7 @@ def test_processing(self): self.assertEqual(interp, interp.new_copy()) - def test_to_json(self): + def test_to_json_error(self): x = np.linspace(0, 1, 200) y = pybamm.StateVector(slice(0, 2)) interp = pybamm.Interpolant(x, 2 * x, y) diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 8e466818f1..2c3d2379ab 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -40,27 +40,28 @@ def test_matrix_operations(self): (self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]]) ) - def test_to_json_matrix(self): + def test_to_from_json(self): arr = pybamm.Matrix(csr_matrix([[0, 1, 0, 0], [0, 0, 0, 1]])) - self.assertEqual( - arr.to_json(), - { - "name": "Sparse Matrix (2, 4)", - "id": mock.ANY, # The value of the ID will change, but want to check it is present - "domains": { - "primary": [], - "secondary": [], - "tertiary": [], - "quaternary": [], - }, - "entries": { - "column_pointers": [0, 1, 2], - "data": [1.0, 1.0], - "row_indices": [1, 3], - "shape": (2, 4), - }, + json_dict = { + "name": "Sparse Matrix (2, 4)", + "id": mock.ANY, # The value of the ID will change, but want to check it is present + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], }, - ) + "entries": { + "column_pointers": [0, 1, 2], + "data": [1.0, 1.0], + "row_indices": [1, 3], + "shape": (2, 4), + }, + } + + self.assertEqual(arr.to_json(), json_dict) + + self.assertEqual(pybamm.Matrix._from_json(json_dict), arr) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index 6001d8906b..62441f4309 100644 --- 
a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -31,7 +31,7 @@ def test_to_equation(self): # Test name self.assertEqual(func1.to_equation(), sympy.Symbol("test_name")) - def test_to_json(self): + def test_to_json_error(self): func = pybamm.Parameter("test_string") with self.assertRaises(NotImplementedError): @@ -115,7 +115,7 @@ def test_function_parameter_to_equation(self): func1.print_name = None self.assertEqual(func1.to_equation(), sympy.Symbol("func")) - def test_to_json(self): + def test_to_json_error(self): func = pybamm.FunctionParameter("test", {"x": pybamm.Scalar(1)}) with self.assertRaises(NotImplementedError): diff --git a/tests/unit/test_expression_tree/test_scalar.py b/tests/unit/test_expression_tree/test_scalar.py index 9d990e354d..34ea1aa514 100644 --- a/tests/unit/test_expression_tree/test_scalar.py +++ b/tests/unit/test_expression_tree/test_scalar.py @@ -45,10 +45,13 @@ def test_copy(self): b = a.create_copy() self.assertEqual(a, b) - def test_to_json(self): + def test_to_from_json(self): a = pybamm.Scalar(5) + json_dict = {"name": "5.0", "id": mock.ANY, "value": 5.0} - self.assertEqual(a.to_json(), {"name": "5.0", "id": mock.ANY, "value": 5.0}) + self.assertEqual(a.to_json(), json_dict) + + self.assertEqual(pybamm.Scalar._from_json(json_dict), a) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_state_vector.py b/tests/unit/test_expression_tree/test_state_vector.py index 0165d1d512..9897b9a027 100644 --- a/tests/unit/test_expression_tree/test_state_vector.py +++ b/tests/unit/test_expression_tree/test_state_vector.py @@ -63,7 +63,7 @@ def test_failure(self): with self.assertRaisesRegex(TypeError, "all y_slices must be slice objects"): pybamm.StateVector(slice(0, 10), 1) - def test_to_json(self): + def test_to_from_json(self): array = np.array([1, 2, 3, 4, 5]) sv = pybamm.StateVector(slice(0, 10), evaluation_array=array) @@ -88,6 +88,8 @@ def test_to_json(self): self.assertEqual(sv.to_json(), json_dict) + self.assertEqual(pybamm.StateVector._from_json(json_dict), sv) + class TestStateVectorDot(TestCase): def test_evaluate(self): diff --git a/tests/unit/test_expression_tree/test_variable.py b/tests/unit/test_expression_tree/test_variable.py index f4f3029c75..b350cb794d 100644 --- a/tests/unit/test_expression_tree/test_variable.py +++ b/tests/unit/test_expression_tree/test_variable.py @@ -63,7 +63,7 @@ def test_to_equation(self): # Test name self.assertEqual(pybamm.Variable("name").to_equation(), sympy.Symbol("name")) - def test_to_json(self): + def test_to_json_error(self): func = pybamm.Variable("test_string") with self.assertRaises(NotImplementedError): func.to_json() From fbc8f6fd682725dc30f39c818d281ffab907079c Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 22 Sep 2023 13:42:30 +0000 Subject: [PATCH 08/29] (wip) testing: add draft de/serialisation tests allow interpolant to be serialised fix concatenation with debug mode switch msmr warnings --- pybamm/expression_tree/concatenations.py | 4 +- pybamm/expression_tree/interpolant.py | 38 +++- .../full_battery_models/base_battery_model.py | 4 + .../full_battery_models/lithium_ion/msmr.py | 4 +- .../test_expression_tree/test_interpolant.py | 40 +++- .../test_expression_tree/test_state_vector.py | 6 + tests/unit/test_serialisation/__init__.py | 0 .../test_serialisation/test_serialisation.py | 182 ++++++++++++++++++ 8 files changed, 268 insertions(+), 10 deletions(-) create mode 100644 tests/unit/test_serialisation/__init__.py 
create mode 100644 tests/unit/test_serialisation/test_serialisation.py diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index af3db72846..d393cc6647 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -48,10 +48,10 @@ def _from_json(cls, *children, name, domains, concat_fun=None): # PL: update this one - I guess we still want it to take 'snippet' rather than the list? to be the same as the others? instance = cls.__new__(cls) - super(Concatenation, instance).__init__(name, children, domains=domains) - instance.concatenation_function = concat_fun + super(Concatenation, instance).__init__(name, children, domains=domains) + return instance def __str__(self): diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 5234e5e927..20d4e0180b 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -202,6 +202,27 @@ def __init__( self.interpolator = interpolator self.extrapolate = extrapolate + @classmethod + def _from_json(cls, snippet: dict): + """Create an Interpolant object from JSON data""" + instance = cls.__new__(cls) + + if len(snippet["x"]) == 1: + x = [np.array(x) for x in snippet["x"]] + else: + x = tuple(np.array(x) for x in snippet["x"]) + + instance.__init__( + x, + np.array(snippet["y"]), + snippet["children"], + name=snippet["name"], + interpolator=snippet["interpolator"], + extrapolate=snippet["extrapolate"], + ) + + return instance + @property def entries_string(self): return self._entries_string @@ -292,6 +313,17 @@ def _function_evaluate(self, evaluated_children): raise ValueError("Invalid dimension: {0}".format(self.dimension)) def to_json(self): - raise NotImplementedError( - "pybamm.Interpolant: Serialisation is only implemented for discretised models." - ) + """ + Method to serialise an Interpolant object into JSON. 
+ """ + + json_dict = { + "name": self.name, + "id": self.id, + "x": [x_item.tolist() for x_item in self.x], + "y": self.y.tolist(), + "interpolator": self.interpolator, + "extrapolate": self.extrapolate, + } + + return json_dict diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index 841bf53a81..deb312b379 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -604,6 +604,10 @@ def __init__(self, extra_options): if option in ["working electrode"]: pass else: + # serialised options save tuples as lists which need to be converted + if isinstance(value, list) and len(value) == 2: + value = tuple(value) + if isinstance(value, str) or option in [ "dimensionality", "operating mode", diff --git a/pybamm/models/full_battery_models/lithium_ion/msmr.py b/pybamm/models/full_battery_models/lithium_ion/msmr.py index 3ca07c4ef8..f1ec7f90bd 100644 --- a/pybamm/models/full_battery_models/lithium_ion/msmr.py +++ b/pybamm/models/full_battery_models/lithium_ion/msmr.py @@ -19,7 +19,7 @@ def __init__(self, options=None, name="MSMR", build=True): options["open-circuit potential"] ) ) - elif "particle" in options and options["particle"] == "MSMR": + elif "particle" in options and options["particle"] != "MSMR": raise pybamm.OptionError( "'particle' must be 'MSMR' for MSMR not '{}'".format( options["particle"] @@ -27,7 +27,7 @@ def __init__(self, options=None, name="MSMR", build=True): ) elif ( "intercalation kinetics" in options - and options["intercalation kinetics"] == "MSMR" + and options["intercalation kinetics"] != "MSMR" ): raise pybamm.OptionError( "'intercalation kinetics' must be 'MSMR' for MSMR not '{}'".format( diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 7389ff183a..93009adf0d 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -5,6 +5,7 @@ import pybamm import unittest +import unittest.mock as mock import numpy as np @@ -326,12 +327,45 @@ def test_processing(self): self.assertEqual(interp, interp.new_copy()) def test_to_json_error(self): - x = np.linspace(0, 1, 200) + x = np.linspace(0, 1, 10) y = pybamm.StateVector(slice(0, 2)) interp = pybamm.Interpolant(x, 2 * x, y) - with self.assertRaises(NotImplementedError): - interp.to_json() + self.assertEqual( + interp.to_json(), + { + "name": "interpolating_function", + "id": mock.ANY, + "x": [ + [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0, + ] + ], + "y": [ + 0.0, + 0.2222222222222222, + 0.4444444444444444, + 0.6666666666666666, + 0.8888888888888888, + 1.1111111111111112, + 1.3333333333333333, + 1.5555555555555554, + 1.7777777777777777, + 2.0, + ], + "interpolator": "linear", + "extrapolate": True, + }, + ) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_state_vector.py b/tests/unit/test_expression_tree/test_state_vector.py index 9897b9a027..18025c0aa3 100644 --- a/tests/unit/test_expression_tree/test_state_vector.py +++ b/tests/unit/test_expression_tree/test_state_vector.py @@ -64,6 +64,9 @@ def test_failure(self): pybamm.StateVector(slice(0, 10), 1) def test_to_from_json(self): + original_debug_mode = pybamm.settings.debug_mode + pybamm.settings.debug_mode = False + array = 
np.array([1, 2, 3, 4, 5]) sv = pybamm.StateVector(slice(0, 10), evaluation_array=array) @@ -90,6 +93,9 @@ def test_to_from_json(self): self.assertEqual(pybamm.StateVector._from_json(json_dict), sv) + # Turn debug mode back to what is was before + pybamm.settings.debug_mode = original_debug_mode + class TestStateVectorDot(TestCase): def test_evaluate(self): diff --git a/tests/unit/test_serialisation/__init__.py b/tests/unit/test_serialisation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py new file mode 100644 index 0000000000..72d4a1a072 --- /dev/null +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -0,0 +1,182 @@ +# +# Tests for the serialisation class +# +from tests import TestCase +import pybamm + +pybamm.settings.debug_mode = True + +import numpy as np +import unittest + + +class TestSerialise(TestCase): + # test lithium models + def test_spm_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lithium_ion.SPM() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_equal(solution.all_ys[x], new_solution.all_ys[x]) + + def test_spme_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lithium_ion.SPMe() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_equal(solution.all_ys[x], new_solution.all_ys[x]) + + def test_mpm_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lithium_ion.MPM() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) + + def test_dfn_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lithium_ion.DFN() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) + + def test_newman_tobias_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lithium_ion.NewmanTobias() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) + + def test_msmr_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lithium_ion.MSMR({"number of MSMR reactions": ("6", "4")}) + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + 
sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x], decimal=3 + ) + + # test lead-acid models + def test_lead_acid_full_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lead_acid.Full() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) + + def test_loqs_serialisation_recreation(self): + t = [0, 3600] + + model = pybamm.lead_acid.LOQS() + sim = pybamm.Simulation(model) + solution = sim.solve(t) + + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) + + # def test_thevenin_serialisation_recreation(self): + # t = [0, 3600] + + # model = pybamm.equivalent_circuit.Thevenin() + # sim = pybamm.Simulation(model) + # solution = sim.solve(t) + + # sim.save_model("test_model") + + # new_model = pybamm.load_model("test_model.json") + # new_solver = new_model.default_solver + # new_solution = new_solver.solve(new_model, t) + + # for x, val in enumerate(solution.all_ys): + # np.testing.assert_array_almost_equal( + # solution.all_ys[x], new_solution.all_ys[x] + # ) + + +if __name__ == "__main__": + print("Add -v for more debug output") + import sys + + if "-v" in sys.argv: + debug = True + pybamm.settings.debug_mode = True + unittest.main() From 4745484f7b00943c8a0ff521817bde64fccc0a90 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 22 Sep 2023 16:43:21 +0000 Subject: [PATCH 09/29] (wip) tests: add _from_json tests with children allow BaseModel to run without rhs --- pybamm/expression_tree/binary_operators.py | 24 +++-- pybamm/expression_tree/unary_operators.py | 49 +++++++++- pybamm/models/base_model.py | 81 ++++++++++++++-- .../full_battery_models/base_battery_model.py | 5 +- pybamm/solvers/base_solver.py | 11 ++- .../test_binary_operators.py | 95 +++++++++++-------- .../test_concatenations.py | 43 ++++++--- .../unit/test_expression_tree/test_symbol.py | 29 +++--- .../test_unary_operators.py | 70 +++++++------- .../test_serialisation/test_serialisation.py | 34 ++++--- tests/unit/test_simulation.py | 20 ++++ 11 files changed, 324 insertions(+), 137 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 30a81ee416..56f3154be9 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -571,14 +571,6 @@ def __init__(self, name, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json( - snippet["name"], snippet["children"][0], snippet["children"][1] - ) - return instance - def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Heaviside should always be multiplied by 
something else so hopefully don't @@ -610,6 +602,14 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("<=", left, right) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = cls.__new__(cls) + + instance.__init__(snippet["children"][0], snippet["children"][1]) + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{!s} <= {!s}".format(self.left, self.right) @@ -627,6 +627,14 @@ class NotEqualHeaviside(_Heaviside): def __init__(self, left, right): super().__init__("<", left, right) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.BinaryOperator._from_json()`.""" + instance = cls.__new__(cls) + + instance.__init__(snippet["children"][0], snippet["children"][1]) + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{!s} < {!s}".format(self.left, self.right) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index b4db6b6528..2b85309469 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -150,6 +150,12 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("abs", child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = super()._from_json("abs", snippet) + return instance + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return sign(self.child) * self.child.diff(variable) @@ -176,6 +182,12 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("sign", child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = super()._from_json("sign", snippet) + return instance + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return pybamm.Scalar(0) @@ -206,6 +218,12 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("floor", child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = super()._from_json("floor", snippet) + return instance + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return pybamm.Scalar(0) @@ -228,6 +246,12 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("ceil", child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = super()._from_json("ceil", snippet) + return instance + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return pybamm.Scalar(0) @@ -293,6 +317,25 @@ def __init__(self, child, index, name=None, check_size=True): if isinstance(index, int): self.clear_domains() + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = cls.__new__(cls) + + index = slice( + snippet["index"]["start"], + snippet["index"]["stop"], + snippet["index"]["step"], + ) + + instance.__init__( + snippet["children"][0], + index, + name=snippet["name"], + check_size=snippet["check_size"], + ) + return instance + def _unary_jac(self, child_jac): """See :meth:`pybamm.UnaryOperator._unary_jac()`.""" @@ -345,7 +388,11 @@ def to_json(self): json_dict = { "name": self.name, "id": self.id, - "domains": 
self.domains, + "index": { + "start": self.slice.start, + "stop": self.slice.stop, + "step": self.slice.step, + }, "check_size": False, } diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 80dda2808d..21648f3dfc 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -10,6 +10,7 @@ import pybamm from pybamm.expression_tree.operations.latexify import Latexify +from pybamm.expression_tree.operations.serialise import Serialise class BaseModel: @@ -123,11 +124,65 @@ def __init__(self, name="Unnamed model"): self.is_discretised = False self.y_slices = None + # PL: Next up, how to pass in the non-standard variables, if necessary. @classmethod - def deserialise(cls, properties: dict): - raise NotImplementedError( - "BaseModel: Serialisation not yet implemented for non-battery models." - ) + def deserialise( + cls, properties: dict + ): # PL: maybe option up here as output_mesh=true to output a tuple, (model, mesh) rather than just updating the variables and leaving it at that. + """ + Create a model instance from a serialised object. + """ + instance = cls.__new__(cls) + + # append the model name with _saved to differentiate + instance.__init__(name=properties["name"] + "_saved") + + # PL: what to do with the options? + + # Initialise model with stored variables that have already been discretised + instance._concatenated_rhs = properties["concatenated_rhs"] + instance._concatenated_algebraic = properties["concatenated_algebraic"] + instance._concatenated_initial_conditions = properties[ + "concatenated_initial_conditions" + ] + + instance.len_rhs = instance.concatenated_rhs.size + instance.len_alg = instance.concatenated_algebraic.size + instance.len_rhs_and_alg = instance.len_rhs + instance.len_alg + + instance.bounds = properties["bounds"] + instance.events = properties["events"] + instance.mass_matrix = properties["mass_matrix"] + instance.mass_matrix_inv = properties["mass_matrix_inv"] + + # add optional properties not required for model to solve + if properties["variables"]: + instance._variables = pybamm.FuzzyDict(properties["variables"]) + + # assign meshes to each variable + for var in instance._variables.values(): + if var.domain != []: + var.mesh = properties["mesh"][var.domain] + else: + var.mesh = None + + if var.domains["secondary"] != []: + var.secondary_mesh = properties["mesh"][var.domains["secondary"]] + else: + var.secondary_mesh = None + + instance._geometry = pybamm.Geometry(properties["geometry"]) + else: + # Delete the default variables which have not been discretised + instance._variables = pybamm.FuzzyDict({}) + + # PL: Simulation(new_model, new_mesh) + # doesn't work because the model is already discretised, you can't give it a new mesh. + + # Model has already been discretised + instance.is_discretised = True + + return instance @property def name(self): @@ -1117,9 +1172,21 @@ def process_parameters_and_discretise(self, symbol, parameter_values, disc): return disc_symbol def save_model(self, filename=None, mesh=None, variables=None): - raise NotImplementedError( - "BaseModel: Serialisation not yet implemented for non-battery models." - ) + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. 
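+        mesh: :class:`pybamm.Mesh`, optional
+            The mesh used to discretise the model. This must be provided if
+            `variables` are passed, as the saved variables are re-attached to
+            their meshes when the model is read back in.
+        variables: dict, optional
+            The discretised model variables (e.g. `model.variables`). Not needed
+            to solve the loaded model, but without them the plotting tools will
+            not be available.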
+ """ + if variables and not mesh: + raise ValueError( + "Serialisation: Please provide the mesh if variables are required" + ) + + Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) # helper functions for finding symbols diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index deb312b379..ac7d16f3ed 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -6,6 +6,8 @@ from functools import cached_property import warnings +from pybamm.expression_tree.operations.serialise import Serialise + def represents_positive_integer(s): """Check if a string represents a positive integer""" @@ -17,9 +19,6 @@ def represents_positive_integer(s): return val > 0 -from pybamm.expression_tree.operations.serialise import Serialise - - class BatteryModelOptions(pybamm.FuzzyDict): """ Attributes diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 7740006310..13f8a22f34 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -707,9 +707,14 @@ def solve( # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: if not isinstance(self, pybamm.DummySolver): - raise pybamm.ModelError( - "Cannot solve empty model, use `pybamm.DummySolver` instead" - ) + # check a discretised model without original paramaters is not being used + if not ( + model.concatenated_rhs is not None + or model.concatenated_algebraic is not None + ): + raise pybamm.ModelError( + "Cannot solve empty model, use `pybamm.DummySolver` instead" + ) # t_eval can only be None if the solver is an algebraic solver. In that case # set it to 0 diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 9a66e3a639..18f654566b 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -780,54 +780,69 @@ def test_to_equation(self): def test_to_json(self): # Test Addition - self.assertEqual( - pybamm.Addition(2, 4).to_json(), - { - "name": "+", - "id": mock.ANY, - "domains": EMPTY_DOMAINS, - }, - ) + add_json = { + "name": "+", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + add = pybamm.Addition(2, 4) + + self.assertEqual(add.to_json(), add_json) + + add_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] + self.assertEqual(pybamm.Addition._from_json(add_json), add) # Test Power - self.assertEqual( - pybamm.Power(7, 2).to_json(), - { - "name": "**", - "id": mock.ANY, - "domains": EMPTY_DOMAINS, - }, - ) + pow_json = { + "name": "**", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + pow = pybamm.Power(7, 2) + self.assertEqual(pow.to_json(), pow_json) + + pow_json["children"] = [pybamm.Scalar(7), pybamm.Scalar(2)] + self.assertEqual(pybamm.Power._from_json(pow_json), pow) # Test Division - self.assertEqual( - pybamm.Division(10, 5).to_json(), - { - "name": "/", - "id": mock.ANY, - "domains": EMPTY_DOMAINS, - }, - ) + div_json = { + "name": "/", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + div = pybamm.Division(10, 5) + self.assertEqual(div.to_json(), div_json) + + div_json["children"] = [pybamm.Scalar(10), pybamm.Scalar(5)] + self.assertEqual(pybamm.Division._from_json(div_json), div) # Test EqualHeaviside - self.assertEqual( - pybamm.EqualHeaviside(2, 4).to_json(), - { - "name": "<=", - "id": mock.ANY, - "domains": EMPTY_DOMAINS, - 
}, - ) + equal_json = { + "name": "<=", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + equal_h = pybamm.EqualHeaviside(2, 4) + self.assertEqual(equal_h.to_json(), equal_json) + + equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] + self.assertEqual(pybamm.EqualHeaviside._from_json(equal_json), equal_h) # Test notEqualHeaviside - self.assertEqual( - pybamm.NotEqualHeaviside(2, 4).to_json(), - { - "name": "<", - "id": mock.ANY, - "domains": EMPTY_DOMAINS, - }, - ) + not_equal_json = { + "name": "<", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + ne_h = pybamm.NotEqualHeaviside(2, 4) + self.assertEqual(ne_h.to_json(), not_equal_json) + + not_equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] + self.assertEqual(pybamm.NotEqualHeaviside._from_json(not_equal_json), ne_h) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index f846220a77..2da745158a 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -383,7 +383,7 @@ def test_to_equation(self): # Test concat_sym self.assertEqual(pybamm.Concatenation(a, b).to_equation(), func_symbol) - def test_to_json(self): + def test_to_from_json(self): # test DomainConcatenation mesh = get_mesh_for_testing() a = pybamm.Symbol("a", domain=["negative electrode"]) @@ -420,26 +420,41 @@ def test_to_json(self): json_dict, ) - # test NumpyConcatenation + # manually add children + json_dict["children"] = [a, b] + + # check symbol re-creation + self.assertEqual(pybamm.pybamm.DomainConcatenation._from_json(json_dict), conc) + + # ----------------------------- + # test NumpyConcatenation ----- + # ----------------------------- + y = np.linspace(0, 1, 15)[:, np.newaxis] a_np = pybamm.Vector(y[:5]) b_np = pybamm.Vector(y[5:9]) c_np = pybamm.Vector(y[9:]) conc_np = pybamm.NumpyConcatenation(a_np, b_np, c_np) - self.assertEqual( - conc_np.to_json(), - { - "name": "numpy_concatenation", - "id": mock.ANY, - "domains": { - "primary": [], - "secondary": [], - "tertiary": [], - "quaternary": [], - }, + np_json = { + "name": "numpy_concatenation", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], }, - ) + } + + # test to_json + self.assertEqual(conc_np.to_json(), np_json) + + # add children + np_json["children"] = [a_np, b_np, c_np] + + # test _from_json + self.assertEqual(pybamm.NumpyConcatenation._from_json(np_json), conc_np) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 17c5f0a02f..a2cea1801e 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -487,24 +487,27 @@ def test_numpy_array_ufunc(self): x = pybamm.Symbol("x") self.assertEqual(np.exp(x), pybamm.exp(x)) - def test_to_json(self): + def test_to_from_json(self): symc1 = pybamm.Symbol("child1", domain=["domain_1"]) symc2 = pybamm.Symbol("child2", domain=["domain_2"]) symp = pybamm.Symbol("parent", domain=["domain_3"], children=[symc1, symc2]) - self.assertEqual( - symp.to_json(), - { - "name": "parent", - "id": mock.ANY, - "domains": { - "primary": ["domain_3"], - "secondary": [], - "tertiary": [], - "quaternary": [], - }, + json_dict = { + "name": "parent", + "id": mock.ANY, + "domains": { + "primary": ["domain_3"], + "secondary": [], + "tertiary": [], + "quaternary": [], }, - ) + } + + 
self.assertEqual(symp.to_json(), json_dict) + + json_dict["children"] = [symc1, symc2] + + self.assertEqual(pybamm.Symbol._from_json(json_dict), symp) class TestIsZero(TestCase): diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index e8fc7c7be0..3c9de976d6 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -669,41 +669,42 @@ def test_explicit_time_integral(self): self.assertEqual(expr.new_copy(), expr) self.assertFalse(expr.is_constant()) - def test_to_json(self): + def test_to_from_json(self): # UnaryOperator a = pybamm.Symbol("a", domain=["test"]) un = pybamm.UnaryOperator("unary test", a) - self.assertEqual( - un.to_json(), - { - "name": "unary test", - "id": mock.ANY, - "domains": { - "primary": ["test"], - "secondary": [], - "tertiary": [], - "quaternary": [], - }, + + un_json = { + "name": "unary test", + "id": mock.ANY, + "domains": { + "primary": ["test"], + "secondary": [], + "tertiary": [], + "quaternary": [], }, - ) + } + + self.assertEqual(un.to_json(), un_json) + + un_json["children"] = [a] + self.assertEqual(pybamm.UnaryOperator._from_json("unary test", un_json), un) # Index vec = pybamm.StateVector(slice(0, 5)) ind = pybamm.Index(vec, 3) - self.assertEqual( - ind.to_json(), - { - "name": "Index[3]", - "id": mock.ANY, - "domains": { - "primary": [], - "secondary": [], - "tertiary": [], - "quaternary": [], - }, - "check_size": False, - }, - ) + + ind_json = { + "name": "Index[3]", + "id": mock.ANY, + "index": {"start": 3, "stop": 4, "step": None}, + "check_size": False, + } + + self.assertEqual(ind.to_json(), ind_json) + + ind_json["children"] = [vec] + self.assertEqual(pybamm.Index._from_json(ind_json), ind) # SpatialOperator spatial_vec = pybamm.SpatialOperator("name", vec) @@ -712,13 +713,14 @@ def test_to_json(self): # ExplicitTimeIntegral expr = pybamm.ExplicitTimeIntegral(pybamm.Parameter("param"), pybamm.Scalar(1)) - self.assertEqual( - expr.to_json(), - { - "name": "explicit time integral", - "id": mock.ANY, - }, - ) + + expr_json = {"name": "explicit time integral", "id": mock.ANY} + + self.assertEqual(expr.to_json(), expr_json) + + expr_json["children"] = [pybamm.Parameter("param")] + expr_json["initial_condition"] = [pybamm.Scalar(1)] + self.assertEqual(pybamm.ExplicitTimeIntegral._from_json(expr_json), expr) if __name__ == "__main__": diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 72d4a1a072..268a4082eb 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -2,6 +2,7 @@ # Tests for the serialisation class # from tests import TestCase +import tests import pybamm pybamm.settings.debug_mode = True @@ -10,7 +11,7 @@ import unittest -class TestSerialise(TestCase): +class TestSerialiseModels(TestCase): # test lithium models def test_spm_serialisation_recreation(self): t = [0, 3600] @@ -153,23 +154,28 @@ def test_loqs_serialisation_recreation(self): solution.all_ys[x], new_solution.all_ys[x] ) - # def test_thevenin_serialisation_recreation(self): - # t = [0, 3600] + def test_thevenin_serialisation_recreation(self): + t = [0, 3600] - # model = pybamm.equivalent_circuit.Thevenin() - # sim = pybamm.Simulation(model) - # solution = sim.solve(t) + model = pybamm.equivalent_circuit.Thevenin() + sim = pybamm.Simulation(model) + solution = sim.solve(t) - # 
sim.save_model("test_model") + sim.save_model("test_model") + + new_model = pybamm.load_model("test_model.json") + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) - # new_model = pybamm.load_model("test_model.json") - # new_solver = new_model.default_solver - # new_solution = new_solver.solve(new_model, t) - # for x, val in enumerate(solution.all_ys): - # np.testing.assert_array_almost_equal( - # solution.all_ys[x], new_solution.all_ys[x] - # ) +class TestSerialiseExpressionTree(TestCase): + def test_tree_walk(self): + pass if __name__ == "__main__": diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py index d0926e5c94..e20d2e0460 100644 --- a/tests/unit/test_simulation.py +++ b/tests/unit/test_simulation.py @@ -327,6 +327,26 @@ def test_save_load_dae(self): sim_load = pybamm.load_sim("test.pickle") self.assertEqual(sim.model.name, sim_load.model.name) + def test_save_load_model(self): + model = pybamm.lead_acid.LOQS({"surface form": "algebraic"}) + model.use_jacobian = True + sim = pybamm.Simulation(model) + + # test exception if not discretised + with self.assertRaises(NotImplementedError): + sim.save_model("sim_save") + + # save after solving + sim.solve([0, 600]) + sim.save_model("sim_save") + + # load model + saved_model = pybamm.load_model("sim_save.json") + + self.assertEqual(model.options, saved_model.options) + + os.remove("sim_save.json") + def test_plot(self): sim = pybamm.Simulation(pybamm.lithium_ion.SPM()) From 80fc250e3c6e25b1e1ff00ba1a5ffec123ccd0e5 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Wed, 27 Sep 2023 16:02:42 +0000 Subject: [PATCH 10/29] testing: add unit tests for Serialise() functions Add to_from_json test for Events --- .../expression_tree/operations/serialise.py | 7 +- pybamm/expression_tree/symbol.py | 1 - tests/unit/test_models/test_event.py | 22 ++ .../test_serialisation/test_serialisation.py | 364 +++++++++++++++++- 4 files changed, 387 insertions(+), 7 deletions(-) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index d27f770451..2f79f0f6f7 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -207,7 +207,7 @@ def load_model( } recon_model_dict["geometry"] = ( - self._reconstruct_geometry(model_data["geometry"]) + self._reconstruct_pybamm_dict(model_data["geometry"]) if "geometry" in model_data.keys() else None ) @@ -255,7 +255,8 @@ def _get_pybamm_class(self, snippet: dict): try: empty_class = self._Empty() empty_class.__class__ = class_ - except: + except TypeError: + # Mesh objects have a different layouts empty_class = self._EmptyDict() empty_class.__class__ = class_ @@ -345,7 +346,7 @@ def _reconstruct_mesh(self, node: dict): return new_mesh - def _reconstruct_geometry(self, obj: dict): + def _reconstruct_pybamm_dict(self, obj: dict): """ pybamm.Geometry can contain PyBaMM symbols as dictionary keys. diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index b0747090cd..dfa4a05bf0 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -1015,7 +1015,6 @@ def to_json(self): "name": self.name, "id": self.id, "domains": self.domains, - # "children": self.children, # the encoder deals with the children itself. 
} return json_dict diff --git a/tests/unit/test_models/test_event.py b/tests/unit/test_models/test_event.py index 7d0d00f000..84b0dcde84 100644 --- a/tests/unit/test_models/test_event.py +++ b/tests/unit/test_models/test_event.py @@ -48,6 +48,28 @@ def test_event_types(self): event = pybamm.Event("my event", pybamm.Scalar(1), event_type) self.assertEqual(event.event_type, event_type) + def test_to_from_json(self): + expression = pybamm.Scalar(1) + event = pybamm.Event("my event", expression) + + event_json = { + "name": "my event", + "event_type": ["EventType.TERMINATION", 0], + } + + event_ser_json = event.to_json() + self.assertEqual(event_ser_json, event_json) + + event_json["expression"] = expression + + new_event = pybamm.Event._from_json(event_json) + + # check for equal expressions + self.assertEqual(new_event.expression, event.expression) + + # check for equal event types + self.assertEqual(new_event.event_type, event.event_type) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 268a4082eb..de9dfc868d 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -4,11 +4,87 @@ from tests import TestCase import tests import pybamm +from pybamm.expression_tree.operations.serialise import Serialise pybamm.settings.debug_mode = True import numpy as np import unittest +import unittest.mock as mock + +from numpy import testing + + +def scalar_var_dict(): + """variable, json pair for a pybamm.Scalar instance""" + a = pybamm.Scalar(5) + a_dict = { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scalar", + "name": "5.0", + "id": mock.ANY, + "value": 5.0, + "children": [], + } + + return a, a_dict + + +def mesh_var_dict(): + """mesh, json pair for a pybamm.Mesh instance""" + + r = pybamm.SpatialVariable( + "r", domain=["negative particle"], coord_sys="spherical polar" + ) + + geometry = { + "negative particle": {r: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}} + } + + submesh_types = {"negative particle": pybamm.Uniform1DSubMesh} + var_pts = {r: 20} + + # create mesh + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + mesh_json = { + "py/object": "pybamm.meshes.meshes.Mesh", + "py/id": mock.ANY, + "submesh_pts": {"negative particle": {"r": 20}}, + "base_domains": ["negative particle"], + "sub_meshes": { + "negative particle": { + "py/object": "pybamm.meshes.one_dimensional_submeshes.Uniform1DSubMesh", + "py/id": mock.ANY, + "edges": [ + 0.0, + 0.05, + 0.1, + 0.15000000000000002, + 0.2, + 0.25, + 0.30000000000000004, + 0.35000000000000003, + 0.4, + 0.45, + 0.5, + 0.55, + 0.6000000000000001, + 0.65, + 0.7000000000000001, + 0.75, + 0.8, + 0.8500000000000001, + 0.9, + 0.9500000000000001, + 1.0, + ], + "coord_sys": "spherical polar", + } + }, + } + + return mesh, mesh_json class TestSerialiseModels(TestCase): @@ -173,9 +249,291 @@ def test_thevenin_serialisation_recreation(self): ) -class TestSerialiseExpressionTree(TestCase): - def test_tree_walk(self): - pass +class TestSerialise(TestCase): + # test the symbol encoder + + def test_symbol_encoder_symbol(self): + """test basic symbol encoder with & without children""" + + # without children + a, a_dict = scalar_var_dict() + + a_ser_json = Serialise._SymbolEncoder().default(a) + + self.assertEqual(a_ser_json, a_dict) + + # with children + add = pybamm.Addition(2, 4) + add_json = { + "py/id": mock.ANY, + 
"py/object": "pybamm.expression_tree.binary_operators.Addition", + "name": "+", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [ + { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scalar", + "name": "2.0", + "id": mock.ANY, + "value": 2.0, + "children": [], + }, + { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scalar", + "name": "4.0", + "id": mock.ANY, + "value": 4.0, + "children": [], + }, + ], + } + + add_ser_json = Serialise._SymbolEncoder().default(add) + + self.assertEqual(add_ser_json, add_json) + + def test_symbol_encoder_explicitTimeIntegral(self): + """test symbol encoder with initial conditions""" + expr = pybamm.ExplicitTimeIntegral(pybamm.Scalar(5), pybamm.Scalar(1)) + + expr_json = { + "py/object": "pybamm.expression_tree.unary_operators.ExplicitTimeIntegral", + "py/id": mock.ANY, + "name": "explicit time integral", + "id": mock.ANY, + "children": [ + { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": mock.ANY, + "name": "5.0", + "id": mock.ANY, + "value": 5.0, + "children": [], + } + ], + "initial_condition": { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": mock.ANY, + "name": "1.0", + "id": mock.ANY, + "value": 1.0, + "children": [], + }, + } + + expr_ser_json = Serialise._SymbolEncoder().default(expr) + + self.assertEqual(expr_json, expr_ser_json) + + def test_symbol_encoder_event(self): + """test symbol encoder with event""" + + expression = pybamm.Scalar(1) + event = pybamm.Event("my event", expression) + + event_json = { + "py/object": "pybamm.models.event.Event", + "py/id": mock.ANY, + "name": "my event", + "event_type": ["EventType.TERMINATION", 0], + "expression": { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": mock.ANY, + "name": "1.0", + "id": mock.ANY, + "value": 1.0, + "children": [], + }, + } + + event_ser_json = Serialise._SymbolEncoder().default(event) + self.assertEqual(event_ser_json, event_json) + + # test the mesh encoder + def test_mesh_encoder(self): + mesh, mesh_json = mesh_var_dict() + + # serialise mesh + mesh_ser_json = Serialise._MeshEncoder().default(mesh) + + self.assertEqual(mesh_ser_json, mesh_json) + + def test_deconstruct_pybamm_dicts(self): + """tests serialisation of dictionaries with pybamm classes as keys""" + + x = pybamm.SpatialVariable("x", "negative electrode") + + test_dict = {"rod": {x: {"min": 0.0, "max": 2.0}}} + + ser_dict = { + "rod": { + "symbol_x": { + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", + "py/id": mock.ANY, + "name": "x", + "id": mock.ANY, + "domains": { + "primary": ["negative electrode"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [], + }, + "x": {"min": 0.0, "max": 2.0}, + } + } + + self.assertEqual(Serialise()._deconstruct_pybamm_dicts(test_dict), ser_dict) + + def test_get_pybamm_class(self): + # symbol + _, scalar_dict = scalar_var_dict() + + scalar_class = Serialise()._get_pybamm_class(scalar_dict) + + self.assertIsInstance(scalar_class, pybamm.Scalar) + + # mesh + _, mesh_dict = mesh_var_dict() + + mesh_class = Serialise()._get_pybamm_class(mesh_dict) + + self.assertIsInstance(mesh_class, pybamm.Mesh) + + def test_reconstruct_symbol(self): + scalar, scalar_dict = scalar_var_dict() + + new_scalar = Serialise()._reconstruct_symbol(scalar_dict) + + self.assertEqual(new_scalar, scalar) + + def test_reconstruct_expression_tree(self): + y = pybamm.StateVector(slice(0, 1)) + t = 
pybamm.t + equation = 2 * y + t + + equation_json = { + "py/object": "pybamm.expression_tree.binary_operators.Addition", + "py/id": 139691619709376, + "name": "+", + "id": -2595875552397011963, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [ + { + "py/object": "pybamm.expression_tree.binary_operators.Multiplication", + "py/id": 139691619709232, + "name": "*", + "id": 6094209803352873499, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [ + { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": 139691619709040, + "name": "2.0", + "id": 1254626814648295285, + "value": 2.0, + "children": [], + }, + { + "py/object": "pybamm.expression_tree.state_vector.StateVector", + "py/id": 139691619589760, + "name": "y[0:1]", + "id": 5063056989669636089, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "y_slice": [{"start": 0, "stop": 1, "step": None}], + "evaluation_array": [True], + "children": [], + }, + ], + }, + { + "py/object": "pybamm.expression_tree.independent_variable.Time", + "py/id": 139692083289392, + "name": "time", + "id": -3301344124754766351, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [], + }, + ], + } + + new_equation = Serialise()._reconstruct_expression_tree(equation_json) + + self.assertEqual(new_equation, equation) + + def test_reconstruct_mesh(self): + mesh, mesh_dict = mesh_var_dict() + + new_mesh = Serialise()._reconstruct_mesh(mesh_dict) + + testing.assert_array_equal( + new_mesh["negative particle"].edges, mesh["negative particle"].edges + ) + testing.assert_array_equal( + new_mesh["negative particle"].nodes, mesh["negative particle"].nodes + ) + + # reconstructed meshes are only used for plotting, geometry not reconstructed. 
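+        # so accessing ``new_mesh.geometry`` is expected to raise an AttributeError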
+ with self.assertRaisesRegex( + AttributeError, "'Mesh' object has no attribute '_geometry'" + ): + self.assertEqual(new_mesh.geometry, mesh.geometry) + + def test_reconstruct_pybamm_dict(self): + x = pybamm.SpatialVariable("x", "negative electrode") + + test_dict = {"rod": {x: {"min": 0.0, "max": 2.0}}} + + ser_dict = { + "rod": { + "symbol_x": { + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", + "py/id": mock.ANY, + "name": "x", + "id": mock.ANY, + "domains": { + "primary": ["negative electrode"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [], + }, + "x": {"min": 0.0, "max": 2.0}, + } + } + + new_dict = Serialise()._reconstruct_pybamm_dict(ser_dict) + + self.assertEqual(new_dict, test_dict) if __name__ == "__main__": From ac928ab0fbeb52578e7e0903e8010cddf155f122 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 29 Sep 2023 09:00:57 +0000 Subject: [PATCH 11/29] testing: save/load model tests --- .../expression_tree/operations/serialise.py | 11 +- pybamm/models/base_model.py | 3 +- .../test_serialisation/test_serialisation.py | 117 +++++++++++++++++- 3 files changed, 117 insertions(+), 14 deletions(-) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 2f79f0f6f7..edba97a3c4 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from pybamm import BaseBatteryModel + from pybamm import BaseModel class Serialise: @@ -136,9 +136,8 @@ def save_model( model_json["mesh"] = self._MeshEncoder().default(mesh) if variables: - model_json["geometry"] = self._deconstruct_pybamm_dicts( - dict(model._geometry) - ) + if model._geometry: + model_json["geometry"] = self._deconstruct_pybamm_dicts(model._geometry) model_json["variables"] = { k: self._SymbolEncoder().default(v) for k, v in dict(variables).items() } @@ -149,9 +148,7 @@ def save_model( with open(filename + ".json", "w") as f: json.dump(model_json, f) - def load_model( - self, filename: str, battery_model: BaseBatteryModel = None - ) -> BaseBatteryModel: + def load_model(self, filename: str, battery_model: BaseModel = None) -> BaseModel: """ Loads a discretised, ready to solve model into PyBaMM. 
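
# A minimal round-trip sketch of the save/load helpers above (illustrative only;
# the toy ODE model and the file name "tiny_model" are arbitrary and not part of
# this patch; a model must be discretised, e.g. by solving it once, before it can
# be serialised):
import numpy as np
import pybamm
from pybamm.expression_tree.operations.serialise import Serialise

model = pybamm.BaseModel()
c = pybamm.Variable("c")
model.rhs = {c: -c}
model.initial_conditions = {c: 1}
model.variables["c"] = c
pybamm.ScipySolver().solve(model, np.linspace(0, 1))  # discretises the model in place

Serialise().save_model(model, filename="tiny_model")      # writes tiny_model.json
new_model = Serialise().load_model("tiny_model.json")     # returns a ready-to-solve model
pybamm.ScipySolver().solve(new_model, np.linspace(0, 1))
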
diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 21648f3dfc..11d8661be9 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -171,7 +171,8 @@ def deserialise( else: var.secondary_mesh = None - instance._geometry = pybamm.Geometry(properties["geometry"]) + if properties["geometry"]: + instance._geometry = pybamm.Geometry(properties["geometry"]) else: # Delete the default variables which have not been discretised instance._variables = pybamm.FuzzyDict({}) diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index de9dfc868d..d65f95fe93 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -3,16 +3,15 @@ # from tests import TestCase import tests -import pybamm -from pybamm.expression_tree.operations.serialise import Serialise - -pybamm.settings.debug_mode = True - -import numpy as np +import os import unittest import unittest.mock as mock +from datetime import datetime +import numpy as np +import pybamm from numpy import testing +from pybamm.expression_tree.operations.serialise import Serialise def scalar_var_dict(): @@ -535,6 +534,112 @@ def test_reconstruct_pybamm_dict(self): self.assertEqual(new_dict, test_dict) + def test_save_load_model(self): + model = pybamm.lithium_ion.SPM(name="test_spm") + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) + + # test error if not discretised + with self.assertRaisesRegex( + NotImplementedError, + "PyBaMM can only serialise a discretised, ready-to-solve model", + ): + Serialise().save_model(model, filename="test_model") + + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + + # default save + Serialise().save_model(model, filename="test_model") + self.assertTrue(os.path.exists("test_model.json")) + + # default save where filename isn't provided + Serialise().save_model(model) + filename = ( + "test_spm_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M_%S") + ".json" + ) + self.assertTrue(os.path.exists(filename)) + os.remove(filename) + + # default load + new_model = Serialise().load_model("test_model.json") + + # check new model solves + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, [0, 3600]) + + # check an error is raised when plotting the solution + with self.assertRaisesRegex( + AttributeError, + "Variables not provided by the serialised model", + ): + new_solution.plot() + + # load when specifying the battery model to use + newest_model = Serialise().load_model( + "test_model.json", battery_model=pybamm.lithium_ion.SPM + ) + os.remove("test_model.json") + + # check new model solves + newest_solver = newest_model.default_solver + newest_solution = newest_solver.solve(newest_model, [0, 3600]) + + def test_serialised_model_plotting(self): + # models without a mesh + model = pybamm.BaseModel() + c = pybamm.Variable("c") + model.rhs = {c: -c} + model.initial_conditions = {c: 1} + model.variables["c"] = c + model.variables["2c"] = 2 * c + + # setup and discretise + _ = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) + + Serialise().save_model( + model, + variables=model.variables, + filename="test_base_model", + ) + + new_model = Serialise().load_model("test_base_model.json") + os.remove("test_base_model.json") + 
+ new_solution = pybamm.ScipySolver().solve(new_model, np.linspace(0, 1)) + + # check dynamic plot loads + new_solution.plot(["c", "2c"], testing=True) + + # models with a mesh ---------------- + model = pybamm.lithium_ion.SPM(name="test_spm_plotting") + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + + Serialise().save_model( + model, + variables=model.variables, + mesh=mesh, + filename="test_plotting_model", + ) + + new_model = Serialise().load_model("test_plotting_model.json") + os.remove("test_plotting_model.json") + + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, [0, 3600]) + + # check dynamic plot loads + new_solution.plot(testing=True) + if __name__ == "__main__": print("Add -v for more debug output") From 9e323d986653bc1374d0537110581b4e86013847 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 29 Sep 2023 16:23:36 +0000 Subject: [PATCH 12/29] testing: Add integration tests add to_json tests for meshes All but 3 int. tests passing, high accuracy diff failures --- .../notebooks/models/saving_models.ipynb | 42 +++- pybamm/__init__.py | 2 +- pybamm/expression_tree/array.py | 4 - pybamm/expression_tree/concatenations.py | 2 +- pybamm/expression_tree/functions.py | 4 +- pybamm/meshes/one_dimensional_submeshes.py | 5 +- pybamm/meshes/scikit_fem_submeshes.py | 24 +++ pybamm/models/base_model.py | 35 +++- .../full_battery_models/base_battery_model.py | 26 +-- .../test_models/standard_model_tests.py | 38 ++++ tests/unit/test_meshes/test_meshes.py | 24 +++ .../test_one_dimensional_submesh.py | 26 +++ .../test_meshes/test_scikit_fem_submesh.py | 38 ++++ tests/unit/test_models/test_base_model.py | 27 +++ .../test_serialisation/test_serialisation.py | 190 ++++-------------- 15 files changed, 291 insertions(+), 196 deletions(-) diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb index c3c9c90ea8..f2826813d4 100644 --- a/docs/source/examples/notebooks/models/saving_models.ipynb +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -56,7 +56,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -68,7 +68,7 @@ "# Recreate the pybamm model from the JSON file\n", "new_dfn_model = pybamm.load_model(\"sim_model_example.json\")\n", "\n", - "sim_reloaded = pybamm.Simulation(new_dfn_model) # PL: will this work if anything other than the default options are used? 
I guess not...\n", + "sim_reloaded = pybamm.Simulation(new_dfn_model)\n", "sim_reloaded.solve([0, 3600])" ] }, @@ -122,13 +122,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "eaf8ae8b8dd84a99b8b1aecfc132ad83", + "model_id": "b21266c9388043dbbe06e6d93dda3009", "version_major": 2, "version_minor": 0 }, @@ -142,10 +142,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -277,6 +277,36 @@ "new_spm_solution.plot()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Making edits to a serialised model\n", + "\n", + "As mentioned at the begining of this notebook, only models which have already been discretised can be serialised and readh back in. This means that after serialisation, the model *cannot be edited*, as the non-discretised elements of the model such as the original rhs are not saved.\n", + "\n", + "If you are likely to want to save a model and then edit it in the future, you may wish to use the `Simulation.save()` functionality to pickle your simulation, as described in [tutorial 6](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/getting_started/tutorial-6-managing-simulation-outputs.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before finishing we will remove the data files we saved so that we leave the directory as we found it" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.remove(\"example_model.json\")\n", + "os.remove(\"sim_model_example.json\")\n", + "os.remove(\"sim_model_variables.json\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/pybamm/__init__.py b/pybamm/__init__.py index c6376ec0b3..d3eab5cc56 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -191,7 +191,7 @@ # # Serialisation # -from .models.full_battery_models.base_battery_model import load_model +from .models.base_model import load_model # # Spatial Methods diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 0fc74f3209..6807cf9f94 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -166,10 +166,6 @@ def to_json(self): "row_indices": self.entries.indices.tolist(), "column_pointers": self.entries.indptr.tolist(), } - else: - raise TypeError( - f"Ah! Dense matrix! {self.entries}" - ) # PL: Double check this json_dict = { "name": self.name, diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index d393cc6647..7ebbb0c59a 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -45,7 +45,7 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None): @classmethod def _from_json(cls, *children, name, domains, concat_fun=None): - # PL: update this one - I guess we still want it to take 'snippet' rather than the list? to be the same as the others? 
+ """Creates a new Concatenation instance from a json object""" instance = cls.__new__(cls) instance.concatenation_function = concat_fun diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 9743e7c754..a6c47f1650 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -212,9 +212,7 @@ def to_equation(self): eq_list.append(eq) return self._sympy_operator(*eq_list) - def to_json( - self, - ): # PL: I think these ones might actually be present when you build your own function. + def to_json(self): raise NotImplementedError( "pybamm.Function: Serialisation is only implemented for discretised models." ) diff --git a/pybamm/meshes/one_dimensional_submeshes.py b/pybamm/meshes/one_dimensional_submeshes.py index 147ed590cf..d68745daec 100644 --- a/pybamm/meshes/one_dimensional_submeshes.py +++ b/pybamm/meshes/one_dimensional_submeshes.py @@ -34,7 +34,7 @@ def __init__(self, edges, coord_sys, tabs=None): self.internal_boundaries = [] # Add tab locations in terms of "left" and "right" - if tabs: + if tabs and "negative tab" not in tabs.keys(): self.tabs = {} l_z = self.edges[-1] @@ -52,6 +52,9 @@ def near(x, point, tol=3e-16): f"{tab} tab located at {tab_location}, " f"but must be at either 0 or {l_z}" ) + elif tabs: + # tabs have already been calculated by a serialised model + self.tabs = tabs def read_lims(self, lims): # Separate limits and tabs diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index f25dce80b1..ca86bcc11f 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -34,6 +34,9 @@ def __init__(self, edges, coord_sys, tabs): self.npts = len(self.edges["y"]) * len(self.edges["z"]) self.coord_sys = coord_sys + # save tabs for serialisation + self.tabs = tabs + # create mesh self.fem_mesh = skfem.MeshTri.init_tensor(self.edges["y"], self.edges["z"]) @@ -141,6 +144,15 @@ def between(x, interval, tol=3e-16): else: raise pybamm.GeometryError("tab location not valid") + def to_json(self): + json_dict = { + "edges": {k: v.tolist() for k, v in self.edges.items()}, + "coord_sys": self.coord_sys, + "tabs": self.tabs, + } + + return json_dict + class ScikitUniform2DSubMesh(ScikitSubMesh2D): """ @@ -177,6 +189,18 @@ def __init__(self, lims, npts): super().__init__(edges, coord_sys, tabs) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + edges = {k: np.array(v) for k, v in snippet["edges"].items()} + + super(ScikitUniform2DSubMesh, instance).__init__( + edges, snippet["coord_sys"], snippet["tabs"] + ) + + return instance + class ScikitExponential2DSubMesh(ScikitSubMesh2D): """ diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 11d8661be9..46645f992d 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -2,6 +2,7 @@ # Base model class # import numbers +import warnings from collections import OrderedDict import copy @@ -124,11 +125,8 @@ def __init__(self, name="Unnamed model"): self.is_discretised = False self.y_slices = None - # PL: Next up, how to pass in the non-standard variables, if necessary. @classmethod - def deserialise( - cls, properties: dict - ): # PL: maybe option up here as output_mesh=true to output a tuple, (model, mesh) rather than just updating the variables and leaving it at that. + def deserialise(cls, properties: dict): """ Create a model instance from a serialised object. 
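+
+        Parameters
+        ----------
+        properties: dict
+            The serialised model attributes (for example ``concatenated_rhs``,
+            ``concatenated_algebraic``, ``bounds`` and ``events``), typically
+            assembled by :meth:`Serialise.load_model`.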
""" @@ -137,7 +135,7 @@ def deserialise( # append the model name with _saved to differentiate instance.__init__(name=properties["name"] + "_saved") - # PL: what to do with the options? + instance.options = properties["options"] # Initialise model with stored variables that have already been discretised instance._concatenated_rhs = properties["concatenated_rhs"] @@ -177,9 +175,6 @@ def deserialise( # Delete the default variables which have not been discretised instance._variables = pybamm.FuzzyDict({}) - # PL: Simulation(new_model, new_mesh) - # doesn't work because the model is already discretised, you can't give it a new mesh. - # Model has already been discretised instance.is_discretised = True @@ -1183,13 +1178,33 @@ def save_model(self, filename=None, mesh=None, variables=None): based on the model name, and the current datetime. """ if variables and not mesh: - raise ValueError( - "Serialisation: Please provide the mesh if variables are required" + warnings.warn( + """ + Serialisation: Variables are being saved without a mesh. + Plotting may not be available. + """, + pybamm.ModelWarning, ) Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) +def load_model(filename, battery_model: BaseModel = None): + """ + Load in a saved model from a JSON file + + Parameters + ---------- + filename: str + Path to the JSON file containing the serialised model file + battery_model: :class: pybamm.BaseBatteryModel, optional + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override + any model names within the file. If None, the function will look for the saved object + path, present if the original model came from PyBaMM. + """ + return Serialise().load_model(filename, battery_model) + + # helper functions for finding symbols def find_symbol_in_tree(tree, name): if name == tree.name: diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index ac7d16f3ed..5046ba801f 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -605,7 +605,7 @@ def __init__(self, extra_options): else: # serialised options save tuples as lists which need to be converted if isinstance(value, list) and len(value) == 2: - value = tuple(value) + value = tuple(tuple(v) if len(v) == 2 else v for v in value) if isinstance(value, str) or option in [ "dimensionality", @@ -805,11 +805,8 @@ def __init__(self, options=None, name="Unnamed battery model"): super().__init__(name) self.options = options - # PL: Next up, how to pass in the non-standard variables, if necessary. @classmethod - def deserialise( - cls, properties: dict - ): # PL: maybe option up here as output_mesh=true to output a tuple, (model, mesh) rather than just updating the variables and leaving it at that. + def deserialise(cls, properties: dict): """ Create a model instance from a serialised object. """ @@ -857,9 +854,6 @@ def deserialise( # Delete the default variables which have not been discretised instance._variables = pybamm.FuzzyDict({}) - # PL: Simulation(new_model, new_mesh) - # doesn't work because the model is already discretised, you can't give it a new mesh. 
- # Model has already been discretised instance.is_discretised = True @@ -1462,19 +1456,3 @@ def save_model(self, filename=None, mesh=None, variables=None): ) Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) - - -def load_model(filename, battery_model: BaseBatteryModel = None): - """ - Load in a saved model from a JSON file - - Parameters - ---------- - filename: str - Path to the JSON file containing the serialised model file - battery_model: :class: pybamm.BaseBatteryModel, optional - PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override - any model names within the file. If None, the function will look for the saved object - path, present if the original model came from PyBaMM. - """ - return Serialise().load_model(filename, battery_model) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 9341122d84..0526363a75 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -5,6 +5,7 @@ import tests import numpy as np +import os class StandardModelTest(object): @@ -138,6 +139,42 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") atol=1e-6, ) + def test_serialisation(self, solver=None, t_eval=None): + self.model.save_model( + "test_model", variables=self.model.variables, mesh=self.disc.mesh + ) + + new_model = pybamm.load_model("test_model.json") + + # create new solver for re-created model + if solver is not None: + new_solver = solver + else: + new_solver = new_model.default_solver + + if isinstance(new_model, pybamm.lithium_ion.BaseModel): + new_solver.rtol = 1e-8 + new_solver.atol = 1e-8 + + Crate = abs( + self.parameter_values["Current function [A]"] + / self.parameter_values["Nominal cell capacity [A.h]"] + ) + # don't allow zero C-rate + if Crate == 0: + Crate = 1 + if t_eval is None: + t_eval = np.linspace(0, 3600 / Crate, 100) + + new_solution = new_solver.solve(new_model, t_eval) + + for x, val in enumerate(self.solution.all_ys): + np.testing.assert_array_almost_equal( + new_solution.all_ys[x], self.solution.all_ys[x] + ) + + os.remove("test_model.json") + def test_all( self, param=None, disc=None, solver=None, t_eval=None, skip_output_tests=False ): @@ -145,6 +182,7 @@ def test_all( self.test_processing_parameters(param) self.test_processing_disc(disc) self.test_solving(solver, t_eval) + self.test_serialisation(solver, t_eval) if ( isinstance( diff --git a/tests/unit/test_meshes/test_meshes.py b/tests/unit/test_meshes/test_meshes.py index 6563ba232d..000ec729a5 100644 --- a/tests/unit/test_meshes/test_meshes.py +++ b/tests/unit/test_meshes/test_meshes.py @@ -390,6 +390,30 @@ def test_1plus1D_tabs_right_left(self): # positive tab should be "left" self.assertEqual(mesh["current collector"].tabs["positive tab"], "left") + def test_to_from_json(self): + r = pybamm.SpatialVariable( + "r", domain=["negative particle"], coord_sys="spherical polar" + ) + + geometry = { + "negative particle": {r: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}} + } + + submesh_types = {"negative particle": pybamm.Uniform1DSubMesh} + var_pts = {r: 20} + + # create mesh + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + mesh_json = mesh.to_json() + + expected_json = { + "submesh_pts": {"negative particle": {"r": 20}}, + "base_domains": ["negative particle"], + } + + self.assertEqual(mesh_json, expected_json) + class TestMeshGenerator(TestCase): def 
test_init_name(self): diff --git a/tests/unit/test_meshes/test_one_dimensional_submesh.py b/tests/unit/test_meshes/test_one_dimensional_submesh.py index 207f5f2b6f..a7cafb5e25 100644 --- a/tests/unit/test_meshes/test_one_dimensional_submesh.py +++ b/tests/unit/test_meshes/test_one_dimensional_submesh.py @@ -18,6 +18,32 @@ def test_exceptions(self): with self.assertRaises(pybamm.GeometryError): pybamm.SubMesh1D(edges, None, tabs=tabs) + def test_to_json(self): + edges = np.linspace(0, 1, 10) + tabs = {"negative": {"z_centre": 0}, "positive": {"z_centre": 1}} + mesh = pybamm.SubMesh1D(edges, None, tabs=tabs) + + mesh_json = mesh.to_json() + + expected_json = { + "edges": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0, + ], + "coord_sys": None, + "tabs": {"negative tab": "left", "positive tab": "right"}, + } + + self.assertEqual(mesh_json, expected_json) + class TestUniform1DSubMesh(TestCase): def test_exceptions(self): diff --git a/tests/unit/test_meshes/test_scikit_fem_submesh.py b/tests/unit/test_meshes/test_scikit_fem_submesh.py index 2e646e1085..88bde7941f 100644 --- a/tests/unit/test_meshes/test_scikit_fem_submesh.py +++ b/tests/unit/test_meshes/test_scikit_fem_submesh.py @@ -180,6 +180,44 @@ def test_tab_left_right(self): param.process_geometry(geometry) pybamm.Mesh(geometry, submesh_types, var_pts) + def test_to_json(self): + param = get_param() + geometry = pybamm.battery_geometry( + include_particles=False, options={"dimensionality": 2} + ) + param.process_geometry(geometry) + + var_pts = {"x_n": 10, "x_s": 7, "x_p": 12, "y": 16, "z": 24} + + submesh_types = { + "negative electrode": pybamm.Uniform1DSubMesh, + "separator": pybamm.Uniform1DSubMesh, + "positive electrode": pybamm.Uniform1DSubMesh, + "current collector": pybamm.MeshGenerator(pybamm.ScikitUniform2DSubMesh), + } + + # create mesh + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + mesh_json = mesh.to_json() + + expected_json = { + "submesh_pts": { + "negative electrode": {"x_n": 10}, + "separator": {"x_s": 7}, + "positive electrode": {"x_p": 12}, + "current collector": {"y": 16, "z": 24}, + }, + "base_domains": [ + "negative electrode", + "separator", + "positive electrode", + "current collector", + ], + } + + self.assertEqual(mesh_json, expected_json) + class TestScikitFiniteElementChebyshev2DSubMesh(TestCase): def test_mesh_creation(self): diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 4167d5fff5..1274d1a7bf 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -9,6 +9,7 @@ import casadi import numpy as np +from numpy import testing import pybamm @@ -982,6 +983,32 @@ def test_timescale_lengthscale_get_set_not_implemented(self): with self.assertRaises(NotImplementedError): model.length_scales = 1 + def test_save_load_model(self): + model = pybamm.BaseModel() + c = pybamm.Variable("c") + model.rhs = {c: -c} + model.initial_conditions = {c: 1} + model.variables["c"] = c + model.variables["2c"] = 2 * c + + # setup and discretise + solution = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) + + # save model + model.save_model(filename="test_base_model") + + # raises warning if variables are saved + with self.assertWarns(pybamm.ModelWarning): + model.save_model(filename="test_base_model", variables=model.variables) + + new_model = 
pybamm.load_model("test_base_model.json") + + new_solution = pybamm.ScipySolver().solve(new_model, np.linspace(0, 1)) + + # model solutions match + testing.assert_array_equal(solution.all_ys, new_solution.all_ys) + os.remove("test_base_model.json") + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index d65f95fe93..53924b8c2b 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -87,165 +87,63 @@ def mesh_var_dict(): class TestSerialiseModels(TestCase): - # test lithium models - def test_spm_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lithium_ion.SPM() - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_equal(solution.all_ys[x], new_solution.all_ys[x]) - - def test_spme_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lithium_ion.SPMe() - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_equal(solution.all_ys[x], new_solution.all_ys[x]) - - def test_mpm_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lithium_ion.MPM() - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_almost_equal( - solution.all_ys[x], new_solution.all_ys[x] - ) - - def test_dfn_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lithium_ion.DFN() - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_almost_equal( - solution.all_ys[x], new_solution.all_ys[x] - ) - - def test_newman_tobias_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lithium_ion.NewmanTobias() - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_almost_equal( - solution.all_ys[x], new_solution.all_ys[x] - ) - - def test_msmr_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lithium_ion.MSMR({"number of MSMR reactions": ("6", "4")}) - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_almost_equal( - solution.all_ys[x], new_solution.all_ys[x], decimal=3 - 
) - - # test lead-acid models - def test_lead_acid_full_serialisation_recreation(self): - t = [0, 3600] - - model = pybamm.lead_acid.Full() - sim = pybamm.Simulation(model) - solution = sim.solve(t) - - sim.save_model("test_model") - - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_almost_equal( - solution.all_ys[x], new_solution.all_ys[x] - ) - - def test_loqs_serialisation_recreation(self): - t = [0, 3600] + def test_user_defined_model_recreaction(self): + # Start with a base model + model = pybamm.BaseModel() - model = pybamm.lead_acid.LOQS() - sim = pybamm.Simulation(model) - solution = sim.solve(t) + # Define the variables and parameters + x = pybamm.SpatialVariable("x", domain="rod", coord_sys="cartesian") + T = pybamm.Variable("Temperature", domain="rod") + k = pybamm.Parameter("Thermal diffusivity") + + # Write the governing equations + N = -k * pybamm.grad(T) # Heat flux + Q = 1 - pybamm.Function(np.abs, x - 1) # Source term + dTdt = -pybamm.div(N) + Q + model.rhs = {T: dTdt} # add to model + + # Add the boundary and initial conditions + model.boundary_conditions = { + T: { + "left": (pybamm.Scalar(0), "Dirichlet"), + "right": (pybamm.Scalar(0), "Dirichlet"), + } + } + model.initial_conditions = {T: 2 * x - x**2} - sim.save_model("test_model") + # Add desired output variables, geometry, parameters + model.variables = {"Temperature": T, "Heat flux": N, "Heat source": Q} + geometry = {"rod": {x: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(2)}}} + param = pybamm.ParameterValues({"Thermal diffusivity": 0.75}) - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver - new_solution = new_solver.solve(new_model, t) - - for x, val in enumerate(solution.all_ys): - np.testing.assert_array_almost_equal( - solution.all_ys[x], new_solution.all_ys[x] - ) + # Process model and geometry + param.process_model(model) + param.process_geometry(geometry) - def test_thevenin_serialisation_recreation(self): - t = [0, 3600] + # Pick mesh, spatial method, and discretise + submesh_types = {"rod": pybamm.Uniform1DSubMesh} + var_pts = {x: 30} + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + spatial_methods = {"rod": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) - model = pybamm.equivalent_circuit.Thevenin() - sim = pybamm.Simulation(model) - solution = sim.solve(t) + # Solve + solver = pybamm.ScipySolver() + t = np.linspace(0, 1, 100) + solution = solver.solve(model, t) - sim.save_model("test_model") + model.save_model("heat_equation", variables=model._variables, mesh=mesh) + new_model = pybamm.load_model("heat_equation.json") - new_model = pybamm.load_model("test_model.json") - new_solver = new_model.default_solver + new_solver = pybamm.ScipySolver() new_solution = new_solver.solve(new_model, t) for x, val in enumerate(solution.all_ys): np.testing.assert_array_almost_equal( solution.all_ys[x], new_solution.all_ys[x] ) + os.remove("heat_equation.json") class TestSerialise(TestCase): From 2934df462cb5f11912ae22aca0f76b435ae75d7d Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Mon, 2 Oct 2023 10:55:16 +0000 Subject: [PATCH 13/29] Add docs for serialisation update integration test to pass at lower accuracy Remove outputs from example notebook --- .../api/expression_tree/operations/index.rst | 1 + .../expression_tree/operations/serialise.rst | 5 + 
docs/source/examples/index.rst | 1 + .../notebooks/models/saving_models.ipynb | 127 ++---------------- .../expression_tree/operations/serialise.py | 26 ++-- .../test_models/standard_model_tests.py | 7 +- 6 files changed, 36 insertions(+), 131 deletions(-) create mode 100644 docs/source/api/expression_tree/operations/serialise.rst diff --git a/docs/source/api/expression_tree/operations/index.rst b/docs/source/api/expression_tree/operations/index.rst index c084389f1a..67beaca136 100644 --- a/docs/source/api/expression_tree/operations/index.rst +++ b/docs/source/api/expression_tree/operations/index.rst @@ -8,4 +8,5 @@ Classes and functions that operate on the expression tree evaluate jacobian convert_to_casadi + serialise unpack_symbol diff --git a/docs/source/api/expression_tree/operations/serialise.rst b/docs/source/api/expression_tree/operations/serialise.rst new file mode 100644 index 0000000000..daa1b652f1 --- /dev/null +++ b/docs/source/api/expression_tree/operations/serialise.rst @@ -0,0 +1,5 @@ +Serialise +========= + +.. autoclass:: pybamm.expression_tree.operations.serialise.Serialise + :members: diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 4bab430032..36b0d3d81f 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -62,6 +62,7 @@ The notebooks are organised into subfolders, and can be viewed in the galleries notebooks/models/MSMR.ipynb notebooks/models/pouch-cell-model.ipynb notebooks/models/rate-capability.ipynb + notebooks/models/saving_models.ipynb notebooks/models/SEI-on-cracks.ipynb notebooks/models/simulating-ORegan-2022-parameter-set.ipynb notebooks/models/SPM.ipynb diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb index f2826813d4..85ca516a59 100644 --- a/docs/source/examples/notebooks/models/saving_models.ipynb +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -18,17 +18,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], + "outputs": [], "source": [ "%pip install pybamm -q # install PyBaMM if it is not installed\n", "import pybamm\n", @@ -50,20 +42,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Recreate the pybamm model from the JSON file\n", "new_dfn_model = pybamm.load_model(\"sim_model_example.json\")\n", @@ -83,23 +64,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "Variables not provided by the serialised model", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/pliggins/PyBaMM/docs/source/examples/notebooks/models/saving_models.ipynb Cell 7\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m plot_sim\u001b[39m.\u001b[39msolve([\u001b[39m0\u001b[39m, \u001b[39m3600\u001b[39m])\n\u001b[1;32m 6\u001b[0m sims\u001b[39m.\u001b[39mappend(plot_sim)\n\u001b[0;32m----> 8\u001b[0m 
pybamm\u001b[39m.\u001b[39;49mdynamic_plot(sims, time_unit\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mseconds\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", - "File \u001b[0;32m~/PyBaMM/pybamm/plotting/dynamic_plot.py:20\u001b[0m, in \u001b[0;36mdynamic_plot\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[39mCreates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[39marguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39m The 'QuickPlot' object that was created\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 19\u001b[0m kwargs_for_class \u001b[39m=\u001b[39m {k: v \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m kwargs\u001b[39m.\u001b[39mitems() \u001b[39mif\u001b[39;00m k \u001b[39m!=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[0;32m---> 20\u001b[0m plot \u001b[39m=\u001b[39m pybamm\u001b[39m.\u001b[39;49mQuickPlot(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs_for_class)\n\u001b[1;32m 21\u001b[0m plot\u001b[39m.\u001b[39mdynamic_plot(kwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m))\n\u001b[1;32m 22\u001b[0m \u001b[39mreturn\u001b[39;00m plot\n", - "File \u001b[0;32m~/PyBaMM/pybamm/plotting/quick_plot.py:159\u001b[0m, in \u001b[0;36mQuickPlot.__init__\u001b[0;34m(self, solutions, output_variables, labels, colors, linestyles, shading, figsize, n_rows, time_unit, spatial_unit, variable_limits)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[39m# check variables have been provided after any serialisation\u001b[39;00m\n\u001b[1;32m 158\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(\u001b[39mlen\u001b[39m(m\u001b[39m.\u001b[39mvariables) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m \u001b[39mfor\u001b[39;00m m \u001b[39min\u001b[39;00m models):\n\u001b[0;32m--> 159\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mVariables not provided by the serialised model\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 161\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows \u001b[39m=\u001b[39m n_rows \u001b[39mor\u001b[39;00m \u001b[39mint\u001b[39m(\n\u001b[1;32m 162\u001b[0m \u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m np\u001b[39m.\u001b[39msqrt(\u001b[39mlen\u001b[39m(output_variables))\n\u001b[1;32m 163\u001b[0m )\n\u001b[1;32m 164\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_cols \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(np\u001b[39m.\u001b[39mceil(\u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows))\n", - "\u001b[0;31mAttributeError\u001b[0m: Variables not provided by the serialised model" - ] - } - ], + "outputs": [], "source": [ "dfn_models = [dfn_model, new_dfn_model]\n", "sims = []\n", @@ -122,34 +89,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b21266c9388043dbbe06e6d93dda3009", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - 
"interactive(children=(FloatSlider(value=0.0, description='t', max=3600.0, step=36.0), Output()), _dom_classes=…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# using the first simulation, save a new file which includes a list of the model variables\n", "dfn_sim.save_model(\"sim_model_variables\", variables=True)\n", @@ -183,20 +125,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# create the model\n", "spm_model = pybamm.lithium_ion.SPM()\n", @@ -237,34 +168,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a1c0b22c969b45858361b7e9de264e76", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# read back in\n", "new_spm_model = pybamm.load_model(\"example_model.json\")\n", @@ -318,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -337,13 +243,6 @@ "source": [ "pybamm.print_citations()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index edba97a3c4..53e1f357d0 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -7,11 +7,6 @@ import numpy as np import re -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pybamm import BaseModel - class Serialise: """ @@ -84,28 +79,27 @@ class _EmptyDict(dict): def save_model( self, - model: pybamm.BaseBatteryModel, + model: pybamm.BaseModel, mesh: pybamm.Mesh = None, variables: pybamm.FuzzyDict = None, filename: str = None, ): - """ - Saves a discretised model to a JSON file. + """Saves a discretised model to a JSON file. As the model is discretised and ready to solve, only the right hand side, algebraic and initial condition variables are saved. Parameters ---------- - model: : :class:`pybamm.BaseBatteryModel` + model : :class:`pybamm.BaseModel` The discretised model to be saved - mesh: :class: `pybamm.Mesh`, optional + mesh : :class:`pybamm.Mesh` (optional) The mesh the model has been discretised over. Not neccesary to solve the model when read in, but required to use pybamm's plotting tools. - variables: :class: `pybamm.FuzzyDict`, optional + variables: :class:`pybamm.FuzzyDict` (optional) The discretised model varaibles. Not necessary to solve a model, but required to use pybamm's plotting tools. - filename: str, optional + filename: str (optional) The desired name of the JSON file. If no name is provided, one will be created based on the model name, and the current datetime. 
""" @@ -148,7 +142,9 @@ def save_model( with open(filename + ".json", "w") as f: json.dump(model_json, f) - def load_model(self, filename: str, battery_model: BaseModel = None) -> BaseModel: + def load_model( + self, filename: str, battery_model: pybamm.BaseModel = None + ) -> pybamm.BaseModel: """ Loads a discretised, ready to solve model into PyBaMM. @@ -166,14 +162,14 @@ def load_model(self, filename: str, battery_model: BaseModel = None) -> BaseMode filename: str Path to the JSON file containing the serialised model file - battery_model: :class: pybamm.BaseBatteryModel, optional + battery_model: :class:`pybamm.BaseModel` (optional) PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override any model names within the file. If None, the function will look for the saved object path, present if the original model came from PyBaMM. Returns ------- - :class: pybamm.BaseBatteryModel + :class:`pybamm.BaseModel` A PyBaMM model object, of type specified either in the JSON or in `battery_model`. """ diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 0526363a75..d0e38501c9 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -155,6 +155,9 @@ def test_serialisation(self, solver=None, t_eval=None): if isinstance(new_model, pybamm.lithium_ion.BaseModel): new_solver.rtol = 1e-8 new_solver.atol = 1e-8 + accuracy = 6 + else: + accuracy = 5 Crate = abs( self.parameter_values["Current function [A]"] @@ -170,7 +173,7 @@ def test_serialisation(self, solver=None, t_eval=None): for x, val in enumerate(self.solution.all_ys): np.testing.assert_array_almost_equal( - new_solution.all_ys[x], self.solution.all_ys[x] + new_solution.all_ys[x], self.solution.all_ys[x], decimal=accuracy ) os.remove("test_model.json") @@ -182,7 +185,6 @@ def test_all( self.test_processing_parameters(param) self.test_processing_disc(disc) self.test_solving(solver, t_eval) - self.test_serialisation(solver, t_eval) if ( isinstance( @@ -190,6 +192,7 @@ def test_all( ) and not skip_output_tests ): + self.test_serialisation(solver, t_eval) self.test_outputs() From 66d8045b29a26fc595ff82e386b5c89c1a082357 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Mon, 2 Oct 2023 14:53:17 +0000 Subject: [PATCH 14/29] Increase test coverage --- pybamm/expression_tree/binary_operators.py | 87 +------------ .../expression_tree/operations/serialise.py | 12 +- pybamm/expression_tree/unary_operators.py | 9 +- .../test_expression_tree/test_broadcasts.py | 5 +- .../test_expression_tree/test_functions.py | 115 ++++++++++++++++++ .../test_expression_tree/test_interpolant.py | 65 +++++----- .../test_expression_tree/test_parameter.py | 6 + .../test_unary_operators.py | 64 ++++++++++ .../test_serialisation/test_serialisation.py | 28 ++++- 9 files changed, 262 insertions(+), 129 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 56f3154be9..5bff7419d0 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -69,14 +69,14 @@ def __init__(self, name, left, right): self.right = self.children[1] @classmethod - def _from_json(cls, name, snippet: dict): + def _from_json(cls, snippet: dict): """Use to instantiate when deserialising; discretisation has already occured so pre-processing of binaries is not necessary.""" instance = cls.__new__(cls) super(BinaryOperator, instance).__init__( 
- name, + snippet["name"], children=[snippet["children"][0], snippet["children"][1]], domains=snippet["domains"], ) @@ -191,12 +191,6 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("**", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -238,12 +232,6 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("+", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) @@ -267,12 +255,6 @@ def __init__(self, left, right): super().__init__("-", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("-", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -298,12 +280,6 @@ def __init__(self, left, right): super().__init__("*", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("*", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -340,13 +316,6 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - # instance = super(MatrixMultiplication, cls)._from_json("@", left, right) - instance = super()._from_json("@", snippet) - return instance - def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # We shouldn't need this @@ -394,12 +363,6 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("/", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule @@ -444,12 +407,6 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("inner product", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -519,12 +476,6 @@ def __init__(self, left, right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("==", snippet) - return instance - def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Equality should always be multiplied by something else so hopefully don't @@ -602,14 +553,6 @@ def __init__(self, left, 
right): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("<=", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = cls.__new__(cls) - - instance.__init__(snippet["children"][0], snippet["children"][1]) - return instance - def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{!s} <= {!s}".format(self.left, self.right) @@ -627,14 +570,6 @@ class NotEqualHeaviside(_Heaviside): def __init__(self, left, right): super().__init__("<", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = cls.__new__(cls) - - instance.__init__(snippet["children"][0], snippet["children"][1]) - return instance - def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{!s} < {!s}".format(self.left, self.right) @@ -652,12 +587,6 @@ class Modulo(BinaryOperator): def __init__(self, left, right): super().__init__("%", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("%", snippet) - return instance - def _diff(self, variable): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -696,12 +625,6 @@ class Minimum(BinaryOperator): def __init__(self, left, right): super().__init__("minimum", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("minimum", snippet) - return instance - def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "minimum({!s}, {!s})".format(self.left, self.right) @@ -738,12 +661,6 @@ class Maximum(BinaryOperator): def __init__(self, left, right): super().__init__("maximum", left, right) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.BinaryOperator._from_json()`.""" - instance = super()._from_json("maximum", snippet) - return instance - def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "maximum({!s}, {!s})".format(self.left, self.right) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 53e1f357d0..aa84db631e 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -40,9 +40,8 @@ def default(self, node: dict): node_dict["expression"] = self.default(node._expression) return node_dict - json_obj = json.JSONEncoder.default(self, node) # pragma: no cover - node_dict["json"] = json_obj - return node_dict + node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover + return node_dict # pragma: no cover class _MeshEncoder(json.JSONEncoder): """Converts PyBaMM meshes into a JSON-serialisable format""" @@ -63,9 +62,8 @@ def default(self, node: dict): node_dict.update(node.to_json()) return node_dict - json_obj = json.JSONEncoder.default(self, node) # pragma: no cover - node_dict["json"] = json_obj - return node_dict + node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover + return node_dict # pragma: no cover class _Empty: """A dummy class to aid deserialisation""" @@ -137,7 +135,7 @@ def save_model( } if filename is None: - filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M_%S") + filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M") with open(filename + ".json", "w") as f: json.dump(model_json, f) diff --git 
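# A sketch of the shared deserialisation path that makes the per-class
# _from_json() overrides above removable: the operator name now travels in the
# snippet, so any BinaryOperator subclass can be rebuilt the same way. The
# snippet layout mirrors BinaryOperator.to_json() plus reconstructed children.
import pybamm

left = pybamm.StateVector(slice(0, 1))
right = pybamm.Scalar(2.0)
snippet = {
    "name": "*",
    "domains": {"primary": [], "secondary": [], "tertiary": [], "quaternary": []},
    "children": [left, right],
}
rebuilt = pybamm.Multiplication._from_json(snippet)
assert rebuilt == pybamm.Multiplication(left, right)  # same expression tree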
a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 2b85309469..cc2b2a434e 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -184,9 +184,7 @@ def __init__(self, child): @classmethod def _from_json(cls, snippet: dict): - """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = super()._from_json("sign", snippet) - return instance + raise NotImplementedError() def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" @@ -421,6 +419,11 @@ class with a :class:`Matrix` def __init__(self, name, child, domains=None): super().__init__(name, child, domains) + def diff(self, variable): + """See :meth:`pybamm.Symbol.diff()`.""" + # We shouldn't need this + raise NotImplementedError + def to_json(self): raise NotImplementedError( "pybamm.SpatialOperator: Serialisation is only implemented for discretised models." diff --git a/tests/unit/test_expression_tree/test_broadcasts.py b/tests/unit/test_expression_tree/test_broadcasts.py index b91cd7d95c..be8fe1a677 100644 --- a/tests/unit/test_expression_tree/test_broadcasts.py +++ b/tests/unit/test_expression_tree/test_broadcasts.py @@ -350,12 +350,15 @@ def test_diff(self): self.assertIsInstance(d, pybamm.Scalar) self.assertEqual(d.evaluate(y=y), 0) - def test_to_json_error(self): + def test_to_from_json_error(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.PrimaryBroadcast(a, "separator") with self.assertRaises(NotImplementedError): b.to_json() + with self.assertRaises(NotImplementedError): + pybamm.PrimaryBroadcast._from_json({}) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index ac5410d9e1..07bfa7efe8 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import numpy as np import sympy @@ -145,8 +146,30 @@ def test_to_equation(self): # Test Function self.assertEqual(pybamm.Function(np.log, 10).to_equation(), 10.0) + def test_to_from_json_error(self): + a = pybamm.Symbol("a") + funca = pybamm.Function(test_function, a) + + with self.assertRaises(NotImplementedError): + funca.to_json() + + with self.assertRaises(NotImplementedError): + pybamm.Function._from_json({}) + class TestSpecificFunctions(TestCase): + def test_to_json(self): + a = pybamm.InputParameter("a") + fun = pybamm.cos(a) + + expected_json = { + "name": "function (cos)", + "id": mock.ANY, + "function": "cos", + } + + self.assertEqual(fun.to_json(), expected_json) + def test_arcsinh(self): a = pybamm.InputParameter("a") fun = pybamm.arcsinh(a) @@ -180,6 +203,15 @@ def test_arcsinh(self): pybamm.PrimaryBroadcast(pybamm.PrimaryBroadcast(fun, "test"), "test2"), ) + # test creation from json + input_json = { + "name": "arcsinh", + "id": mock.ANY, + "function": "arcsinh", + "children": [a], + } + self.assertEqual(pybamm.Arcsinh._from_json(input_json), fun) + def test_arctan(self): a = pybamm.InputParameter("a") fun = pybamm.arctan(a) @@ -196,6 +228,15 @@ def test_arctan(self): places=5, ) + # test creation from json + input_json = { + "name": "arctan", + "id": mock.ANY, + "function": "arctan", + "children": [a], + } + self.assertEqual(pybamm.Arctan._from_json(input_json), fun) + def test_cos(self): a = pybamm.InputParameter("a") fun = pybamm.cos(a) @@ -213,6 +254,15 @@ def test_cos(self): 
places=5, ) + # test creation from json + input_json = { + "name": "cos", + "id": mock.ANY, + "function": "cos", + "children": [a], + } + self.assertEqual(pybamm.Cos._from_json(input_json), fun) + def test_cosh(self): a = pybamm.InputParameter("a") fun = pybamm.cosh(a) @@ -230,6 +280,15 @@ def test_cosh(self): places=5, ) + # test creation from json + input_json = { + "name": "cosh", + "id": mock.ANY, + "function": "cosh", + "children": [a], + } + self.assertEqual(pybamm.Cosh._from_json(input_json), fun) + def test_exp(self): a = pybamm.InputParameter("a") fun = pybamm.exp(a) @@ -247,6 +306,15 @@ def test_exp(self): places=5, ) + # test creation from json + input_json = { + "name": "exp", + "id": mock.ANY, + "function": "exp", + "children": [a], + } + self.assertEqual(pybamm.Exp._from_json(input_json), fun) + def test_log(self): a = pybamm.InputParameter("a") fun = pybamm.log(a) @@ -276,6 +344,17 @@ def test_log(self): places=5, ) + # test creation from json + a = pybamm.InputParameter("a") + fun = pybamm.log(a) + input_json = { + "name": "log", + "id": mock.ANY, + "function": "log", + "children": [a], + } + self.assertEqual(pybamm.Log._from_json(input_json), fun) + def test_max(self): a = pybamm.StateVector(slice(0, 3)) y_test = np.array([1, 2, 3]) @@ -307,6 +386,15 @@ def test_sin(self): places=5, ) + # test creation from json + input_json = { + "name": "sin", + "id": mock.ANY, + "function": "sin", + "children": [a], + } + self.assertEqual(pybamm.Sin._from_json(input_json), fun) + def test_sinh(self): a = pybamm.InputParameter("a") fun = pybamm.sinh(a) @@ -324,6 +412,15 @@ def test_sinh(self): places=5, ) + # test creation from json + input_json = { + "name": "sinh", + "id": mock.ANY, + "function": "sinh", + "children": [a], + } + self.assertEqual(pybamm.Sinh._from_json(input_json), fun) + def test_sqrt(self): a = pybamm.InputParameter("a") fun = pybamm.sqrt(a) @@ -340,6 +437,15 @@ def test_sqrt(self): places=5, ) + # test creation from json + input_json = { + "name": "sqrt", + "id": mock.ANY, + "function": "sqrt", + "children": [a], + } + self.assertEqual(pybamm.Sqrt._from_json(input_json), fun) + def test_tanh(self): a = pybamm.InputParameter("a") fun = pybamm.tanh(a) @@ -370,6 +476,15 @@ def test_erf(self): places=5, ) + # test creation from json + input_json = { + "name": "erf", + "id": mock.ANY, + "function": "erf", + "children": [a], + } + self.assertEqual(pybamm.Erf._from_json(input_json), fun) + def test_erfc(self): a = pybamm.InputParameter("a") fun = pybamm.erfc(a) diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 93009adf0d..0b5ca5f64a 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -331,41 +331,46 @@ def test_to_json_error(self): y = pybamm.StateVector(slice(0, 2)) interp = pybamm.Interpolant(x, 2 * x, y) - self.assertEqual( - interp.to_json(), - { - "name": "interpolating_function", - "id": mock.ANY, - "x": [ - [ - 0.0, - 0.1111111111111111, - 0.2222222222222222, - 0.3333333333333333, - 0.4444444444444444, - 0.5555555555555556, - 0.6666666666666666, - 0.7777777777777777, - 0.8888888888888888, - 1.0, - ] - ], - "y": [ + print(interp.children) + expected_json = { + "name": "interpolating_function", + "id": mock.ANY, + "x": [ + [ 0.0, + 0.1111111111111111, 0.2222222222222222, + 0.3333333333333333, 0.4444444444444444, + 0.5555555555555556, 0.6666666666666666, + 0.7777777777777777, 0.8888888888888888, - 1.1111111111111112, - 
1.3333333333333333, - 1.5555555555555554, - 1.7777777777777777, - 2.0, - ], - "interpolator": "linear", - "extrapolate": True, - }, - ) + 1.0, + ] + ], + "y": [ + 0.0, + 0.2222222222222222, + 0.4444444444444444, + 0.6666666666666666, + 0.8888888888888888, + 1.1111111111111112, + 1.3333333333333333, + 1.5555555555555554, + 1.7777777777777777, + 2.0, + ], + "interpolator": "linear", + "extrapolate": True, + } + + # check correct writing to json + self.assertEqual(interp.to_json(), expected_json) + + expected_json["children"] = [y] + # check correct re-creation + self.assertEqual(pybamm.Interpolant._from_json(expected_json), interp) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index 62441f4309..deab4a0cff 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -37,6 +37,9 @@ def test_to_json_error(self): with self.assertRaises(NotImplementedError): func.to_json() + with self.assertRaises(NotImplementedError): + pybamm.Parameter._from_json({}) + class TestFunctionParameter(TestCase): def test_function_parameter_init(self): @@ -121,6 +124,9 @@ def test_to_json_error(self): with self.assertRaises(NotImplementedError): func.to_json() + with self.assertRaises(NotImplementedError): + pybamm.FunctionParameter._from_json({}) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index 3c9de976d6..f11c5d5d10 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -53,6 +53,20 @@ def test_negation(self): pybamm.PrimaryBroadcast(pybamm.PrimaryBroadcast(nega, "test"), "test2"), ) + # Test from_json + input_json = { + "name": "-", + "id": -2659857727954094888, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.Negate._from_json(input_json), nega) + def test_absolute(self): a = pybamm.Symbol("a") absa = pybamm.AbsoluteValue(a) @@ -80,6 +94,20 @@ def test_absolute(self): pybamm.PrimaryBroadcast(pybamm.PrimaryBroadcast(absa, "test"), "test2"), ) + # Test from_json + input_json = { + "name": "abs", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.AbsoluteValue._from_json(input_json), absa) + def test_smooth_absolute_value(self): a = pybamm.StateVector(slice(0, 1)) expr = pybamm.smooth_absolute_value(a, 10) @@ -116,6 +144,11 @@ def test_sign(self): ), ) + # Test from_json + with self.assertRaises(NotImplementedError): + # signs are always scalar/array types in a discretised model + pybamm.Sign._from_json({}) + def test_floor(self): a = pybamm.Symbol("a") floora = pybamm.Floor(a) @@ -130,6 +163,20 @@ def test_floor(self): floorc = pybamm.Floor(c) self.assertEqual(floorc.evaluate(), -4) + # Test from_json + input_json = { + "name": "floor", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.Floor._from_json(input_json), floora) + def test_ceiling(self): a = pybamm.Symbol("a") ceila = pybamm.Ceiling(a) @@ -144,6 +191,20 @@ def test_ceiling(self): ceilc = pybamm.Ceiling(c) self.assertEqual(ceilc.evaluate(), -3) + # Test from_json + 
input_json = { + "name": "ceil", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.Ceiling._from_json(input_json), ceila) + def test_gradient(self): # gradient of scalar symbol should fail a = pybamm.Symbol("a") @@ -711,6 +772,9 @@ def test_to_from_json(self): with self.assertRaises(NotImplementedError): spatial_vec.to_json() + with self.assertRaises(NotImplementedError): + pybamm.SpatialOperator._from_json({}) + # ExplicitTimeIntegral expr = pybamm.ExplicitTimeIntegral(pybamm.Parameter("param"), pybamm.Scalar(1)) diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 53924b8c2b..7ef55bd2f3 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -3,6 +3,7 @@ # from tests import TestCase import tests +import json import os import unittest import unittest.mock as mock @@ -305,6 +306,17 @@ def test_get_pybamm_class(self): self.assertIsInstance(mesh_class, pybamm.Mesh) + with self.assertRaises(Exception): + unrecognised_symbol = { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scale", + "name": "5.0", + "id": mock.ANY, + "value": 5.0, + "children": [], + } + Serialise()._get_pybamm_class(unrecognised_symbol) + def test_reconstruct_symbol(self): scalar, scalar_dict = scalar_var_dict() @@ -456,9 +468,7 @@ def test_save_load_model(self): # default save where filename isn't provided Serialise().save_model(model) - filename = ( - "test_spm_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M_%S") + ".json" - ) + filename = "test_spm_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M") + ".json" self.assertTrue(os.path.exists(filename)) os.remove(filename) @@ -480,6 +490,18 @@ def test_save_load_model(self): newest_model = Serialise().load_model( "test_model.json", battery_model=pybamm.lithium_ion.SPM ) + + # Test for error if no model type is provided + with open("test_model.json", "r") as f: + model_data = json.load(f) + del model_data["py/object"] + + with open("test_model.json", "w") as f: + json.dump(model_data, f) + + with self.assertRaises(TypeError): + Serialise().load_model("test_model.json") + os.remove("test_model.json") # check new model solves From d5dd21da07eb1b81be11eceadda0a79978ba6a9e Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Mon, 2 Oct 2023 15:36:42 +0000 Subject: [PATCH 15/29] Fix minor style issues --- pybamm/expression_tree/broadcasts.py | 4 ++-- pybamm/expression_tree/operations/serialise.py | 11 ++++++----- pybamm/expression_tree/parameter.py | 10 ++++++---- pybamm/expression_tree/unary_operators.py | 6 ++++-- pybamm/models/base_model.py | 8 ++++---- pybamm/models/event.py | 4 +++- pybamm/plotting/quick_plot.py | 2 +- pybamm/simulation.py | 11 ++++++----- pybamm/solvers/base_solver.py | 2 +- tests/unit/test_expression_tree/test_array.py | 2 +- tests/unit/test_expression_tree/test_matrix.py | 2 +- tests/unit/test_serialisation/test_serialisation.py | 11 +++++------ 12 files changed, 40 insertions(+), 33 deletions(-) diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index a9bd5c2ee2..50afa526e5 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -52,13 +52,13 @@ def _diff(self, variable): def to_json(self): raise NotImplementedError( - "pybamm.Broadcast: Serialisation is only implemented for discretised models." 
+ "pybamm.Broadcast: Serialisation is only implemented for discretised models" ) @classmethod def _from_json(cls, snippet): raise NotImplementedError( - "pybamm.Broadcast: Please use a discretised model when reading in from JSON." + "pybamm.Broadcast: Please use a discretised model when reading in from JSON" ) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index aa84db631e..e3b3d38472 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -101,7 +101,7 @@ def save_model( The desired name of the JSON file. If no name is provided, one will be created based on the model name, and the current datetime. """ - if model.is_discretised == False: + if model.is_discretised is False: raise NotImplementedError( "PyBaMM can only serialise a discretised, ready-to-solve model." ) @@ -161,14 +161,15 @@ def load_model( filename: str Path to the JSON file containing the serialised model file battery_model: :class:`pybamm.BaseModel` (optional) - PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override - any model names within the file. If None, the function will look for the saved object - path, present if the original model came from PyBaMM. + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will + override any model names within the file. If None, the function will look + for the saved object path, present if the original model came from PyBaMM. Returns ------- :class:`pybamm.BaseModel` - A PyBaMM model object, of type specified either in the JSON or in `battery_model`. + A PyBaMM model object, of type specified either in the JSON or in + `battery_model`. """ with open(filename, "r") as f: diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index abf50faa75..afbfe8ac37 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -51,13 +51,13 @@ def to_equation(self): def to_json(self): raise NotImplementedError( - "pybamm.Parameter: Serialisation is only implemented for discretised models." + "pybamm.Parameter: Serialisation is only implemented for discretised models" ) @classmethod def _from_json(cls, snippet): raise NotImplementedError( - "pybamm.Parameter: Please use a discretised model when reading in from JSON." + "pybamm.Parameter: Please use a discretised model when reading in from JSON" ) @@ -235,11 +235,13 @@ def to_equation(self): def to_json(self): raise NotImplementedError( - "pybamm.FunctionParameter: Serialisation is only implemented for discretised models." + "pybamm.FunctionParameter:" + "Serialisation is only implemented for discretised models." ) @classmethod def _from_json(cls, snippet): raise NotImplementedError( - "pybamm.FunctionParameter: Please use a discretised model when reading in from JSON." + "pybamm.FunctionParameter:" + "Please use a discretised model when reading in from JSON." ) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index cc2b2a434e..7aadae412c 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -426,13 +426,15 @@ def diff(self, variable): def to_json(self): raise NotImplementedError( - "pybamm.SpatialOperator: Serialisation is only implemented for discretised models." + "pybamm.SpatialOperator:" + "Serialisation is only implemented for discretised models." 
) @classmethod def _from_json(cls, snippet): raise NotImplementedError( - "pybamm.SpatialOperator: Please use a discretised model when reading in from JSON." + "pybamm.SpatialOperator:" + "Please use a discretised model when reading in from JSON." ) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 46645f992d..32a8a27258 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -1180,7 +1180,7 @@ def save_model(self, filename=None, mesh=None, variables=None): if variables and not mesh: warnings.warn( """ - Serialisation: Variables are being saved without a mesh. + Serialisation: Variables are being saved without a mesh. Plotting may not be available. """, pybamm.ModelWarning, @@ -1198,9 +1198,9 @@ def load_model(filename, battery_model: BaseModel = None): filename: str Path to the JSON file containing the serialised model file battery_model: :class: pybamm.BaseBatteryModel, optional - PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will override - any model names within the file. If None, the function will look for the saved object - path, present if the original model came from PyBaMM. + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will + override any model names within the file. If None, the function will look + for the saved object path, present if the original model came from PyBaMM. """ return Serialise().load_model(filename, battery_model) diff --git a/pybamm/models/event.py b/pybamm/models/event.py index 105106c470..5bba4cd14b 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -97,7 +97,9 @@ def to_json(self): See :meth:`pybamm.Serialise._SymbolEncoder.default()` """ - # event_type contains string name, for JSON readability, and value for deserialisation. + # event_type contains string name, for JSON readability, + # and value for deserialisation. + json_dict = { "name": self._name, "event_type": [str(self._event_type), self._event_type.value], diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 3f55648225..bfe46b8ed0 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -156,7 +156,7 @@ def __init__( # check variables have been provided after any serialisation if any(len(m.variables) == 0 for m in models): - raise AttributeError(f"Variables not provided by the serialised model") + raise AttributeError("Variables not provided by the serialised model") self.n_rows = n_rows or int( len(output_variables) // np.sqrt(len(output_variables)) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 4a71e819bd..4118118533 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -293,9 +293,10 @@ def update_new_model_events(self, new_model, op): # figure out whether the voltage event is greater than the starting # voltage (charge) or less (discharge) and set the sign of the # event accordingly - if (isinstance(op.value, pybamm.Interpolant) or - isinstance(op.value, pybamm.Multiplication)): - inpt = {"start time":0} + if isinstance(op.value, pybamm.Interpolant) or isinstance( + op.value, pybamm.Multiplication + ): + inpt = {"start time": 0} init_curr = op.value.evaluate(t=0, inputs=inpt).flatten()[0] sign = np.sign(init_curr) else: @@ -1207,8 +1208,8 @@ def save_model( tools will not be availble. Will automatically save meshes as well, required for plotting tools. filename: str, optional - The desired name of the JSON file. If no name is provided, one will be created - based on the model name, and the current datetime. 
+ The desired name of the JSON file. If no name is provided, one will be + created based on the model name, and the current datetime. """ mesh = self.mesh if (mesh or variables) else None variables = self.built_model.variables if variables else None diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 13f8a22f34..cabe36a108 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -707,7 +707,7 @@ def solve( # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: if not isinstance(self, pybamm.DummySolver): - # check a discretised model without original paramaters is not being used + # check for a discretised model without original parameters if not ( model.concatenated_rhs is not None or model.concatenated_algebraic is not None diff --git a/tests/unit/test_expression_tree/test_array.py b/tests/unit/test_expression_tree/test_array.py index 885c5e0851..b75c313f47 100644 --- a/tests/unit/test_expression_tree/test_array.py +++ b/tests/unit/test_expression_tree/test_array.py @@ -47,7 +47,7 @@ def test_to_from_json(self): json_dict = { "name": "Array of shape (3, 1)", - "id": mock.ANY, # The value of the ID will change, but want to check it is present + "id": mock.ANY, "domains": { "primary": [], "secondary": [], diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 2c3d2379ab..055902b15e 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -44,7 +44,7 @@ def test_to_from_json(self): arr = pybamm.Matrix(csr_matrix([[0, 1, 0, 0], [0, 0, 0, 1]])) json_dict = { "name": "Sparse Matrix (2, 4)", - "id": mock.ANY, # The value of the ID will change, but want to check it is present + "id": mock.ANY, "domains": { "primary": [], "secondary": [], diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 7ef55bd2f3..6ae39c05cc 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -2,7 +2,6 @@ # Tests for the serialisation class # from tests import TestCase -import tests import json import os import unittest @@ -273,7 +272,7 @@ def test_deconstruct_pybamm_dicts(self): ser_dict = { "rod": { "symbol_x": { - "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", # noqa: E501 "py/id": mock.ANY, "name": "x", "id": mock.ANY, @@ -342,7 +341,7 @@ def test_reconstruct_expression_tree(self): }, "children": [ { - "py/object": "pybamm.expression_tree.binary_operators.Multiplication", + "py/object": "pybamm.expression_tree.binary_operators.Multiplication", # noqa: E501 "py/id": 139691619709232, "name": "*", "id": 6094209803352873499, @@ -362,7 +361,7 @@ def test_reconstruct_expression_tree(self): "children": [], }, { - "py/object": "pybamm.expression_tree.state_vector.StateVector", + "py/object": "pybamm.expression_tree.state_vector.StateVector", # noqa: E501 "py/id": 139691619589760, "name": "y[0:1]", "id": 5063056989669636089, @@ -424,7 +423,7 @@ def test_reconstruct_pybamm_dict(self): ser_dict = { "rod": { "symbol_x": { - "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", # noqa: E501 "py/id": mock.ANY, "name": "x", "id": mock.ANY, @@ -506,7 +505,7 @@ def test_save_load_model(self): 
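# The positive counterpart of the "no variables" checks in these tests: a
# sketch of saving through the Simulation wrapper with variables=True (which
# also stores the mesh) so the re-loaded model can be solved and plotted.
# "spm_with_vars" is an illustrative filename.
import pybamm

sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
sim.solve([0, 3600])
sim.save_model("spm_with_vars", variables=True)  # writes spm_with_vars.json

reloaded = pybamm.load_model("spm_with_vars.json")
reloaded_solution = reloaded.default_solver.solve(reloaded, [0, 3600])
reloaded_solution.plot()  # possible because variables and mesh were saved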
# check new model solves newest_solver = newest_model.default_solver - newest_solution = newest_solver.solve(newest_model, [0, 3600]) + newest_solver.solve(newest_model, [0, 3600]) def test_serialised_model_plotting(self): # models without a mesh From 6d63732f1819d50b2435a055058f46d673c444e6 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 5 Oct 2023 09:48:46 +0000 Subject: [PATCH 16/29] Remove accidental SpatialOperator.diff() addition --- pybamm/expression_tree/unary_operators.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 7aadae412c..8745e5f33c 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -419,11 +419,6 @@ class with a :class:`Matrix` def __init__(self, name, child, domains=None): super().__init__(name, child, domains) - def diff(self, variable): - """See :meth:`pybamm.Symbol.diff()`.""" - # We shouldn't need this - raise NotImplementedError - def to_json(self): raise NotImplementedError( "pybamm.SpatialOperator:" From 0cc0aeeb323d713e302904b4572af745f3f18425 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 19 Oct 2023 16:24:42 -0700 Subject: [PATCH 17/29] Edits after review * Add pybamm version to JSON file * Re-word missing variable message * Refactor unary_operator _from_json() --- .../expression_tree/operations/serialise.py | 1 + pybamm/expression_tree/unary_operators.py | 28 ++----------------- pybamm/plotting/quick_plot.py | 2 +- .../test_expression_tree/test_interpolant.py | 3 +- .../test_unary_operators.py | 4 +-- .../test_serialisation/test_serialisation.py | 2 +- 6 files changed, 8 insertions(+), 32 deletions(-) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index e3b3d38472..14ff251b6a 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -109,6 +109,7 @@ def save_model( model_json = { "py/object": str(type(model))[8:-2], "py/id": id(model), + "pybamm_version": pybamm.__version__, "name": model.name, "options": model.options, "bounds": [bound.tolist() for bound in model.bounds], diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 8745e5f33c..4c047cf0e6 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -35,13 +35,13 @@ def __init__(self, name, child, domains=None): self.child = self.children[0] @classmethod - def _from_json(cls, name, snippet: dict): + def _from_json(cls, snippet: dict): """Use to instantiate when deserialising""" instance = cls.__new__(cls) super(UnaryOperator, instance).__init__( - name, + snippet["name"], snippet["children"], domains=snippet["domains"], ) @@ -114,12 +114,6 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("-", child) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = super()._from_json("-", snippet) - return instance - def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{}{!s}".format(self.name, self.child) @@ -150,12 +144,6 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("abs", child) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = super()._from_json("abs", snippet) - return instance - 
def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return sign(self.child) * self.child.diff(variable) @@ -216,12 +204,6 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("floor", child) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = super()._from_json("floor", snippet) - return instance - def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return pybamm.Scalar(0) @@ -244,12 +226,6 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("ceil", child) - @classmethod - def _from_json(cls, snippet: dict): - """See :meth:`pybamm.UnaryOperator._from_json()`.""" - instance = super()._from_json("ceil", snippet) - return instance - def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return pybamm.Scalar(0) diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index bfe46b8ed0..584f9ef1be 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -156,7 +156,7 @@ def __init__( # check variables have been provided after any serialisation if any(len(m.variables) == 0 for m in models): - raise AttributeError("Variables not provided by the serialised model") + raise AttributeError("No variables to plot") self.n_rows = n_rows or int( len(output_variables) // np.sqrt(len(output_variables)) diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 0b5ca5f64a..92e9ef86c2 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -326,12 +326,11 @@ def test_processing(self): self.assertEqual(interp, interp.new_copy()) - def test_to_json_error(self): + def test_to_json(self): x = np.linspace(0, 1, 10) y = pybamm.StateVector(slice(0, 2)) interp = pybamm.Interpolant(x, 2 * x, y) - print(interp.children) expected_json = { "name": "interpolating_function", "id": mock.ANY, diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index f11c5d5d10..7e6c71e1dc 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -56,7 +56,7 @@ def test_negation(self): # Test from_json input_json = { "name": "-", - "id": -2659857727954094888, + "id": mock.ANY, "domains": { "primary": [], "secondary": [], @@ -749,7 +749,7 @@ def test_to_from_json(self): self.assertEqual(un.to_json(), un_json) un_json["children"] = [a] - self.assertEqual(pybamm.UnaryOperator._from_json("unary test", un_json), un) + self.assertEqual(pybamm.UnaryOperator._from_json(un_json), un) # Index vec = pybamm.StateVector(slice(0, 5)) diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 6ae39c05cc..533baa718f 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -481,7 +481,7 @@ def test_save_load_model(self): # check an error is raised when plotting the solution with self.assertRaisesRegex( AttributeError, - "Variables not provided by the serialised model", + "No variables to plot", ): new_solution.plot() From 616c0d8acd9c4521eef41170228e14a00242c287 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 20 Oct 2023 14:52:19 -0700 Subject: [PATCH 18/29] Serialisation: fix integration tests --- 
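# The pybamm_version field added above is written alongside the class path, so
# a saved file can be sanity-checked without rebuilding the model. A small
# sketch, assuming "test_model.json" was produced by one of the save_model
# calls in these tests.
import json

with open("test_model.json", "r") as f:
    model_json = json.load(f)

print(model_json["pybamm_version"])  # PyBaMM version that wrote the file
print(model_json["py/object"])       # class path used by pybamm.load_model
print(model_json["name"])            # model name recorded at save time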
.../examples/notebooks/models/saving_models.ipynb | 8 ++++++-- pybamm/expression_tree/operations/serialise.py | 14 +++++++++++++- .../full_battery_models/base_battery_model.py | 4 ---- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb index 85ca516a59..c3f72bc4e4 100644 --- a/docs/source/examples/notebooks/models/saving_models.ipynb +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -65,7 +65,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "raises-exception" + ] + }, "outputs": [], "source": [ "dfn_models = [dfn_model, new_dfn_model]\n", @@ -261,7 +265,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.6" }, "orig_nbformat": 4 }, diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index 14ff251b6a..b54f7b1078 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -178,7 +178,7 @@ def load_model( recon_model_dict = { "name": model_data["name"], - "options": model_data["options"], + "options": self._convert_options(model_data["options"]), "bounds": tuple(np.array(bound) for bound in model_data["bounds"]), "concatenated_rhs": self._reconstruct_expression_tree( model_data["concatenated_rhs"] @@ -383,3 +383,15 @@ def recurse(obj): return obj return recurse(obj) + + def _convert_options(self, d): + """ + Converts a dictionary with nested lists to nested tuples, + used to convert model options back into correct format + """ + if isinstance(d, dict): + return {k: self._convert_options(v) for k, v in d.items()} + elif isinstance(d, list): + return tuple(self._convert_options(item) for item in d) + else: + return d diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index cd0b256113..d5593fa55e 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -630,10 +630,6 @@ def __init__(self, extra_options): ]: # some options accept non-strings value = (value,) else: - # serialised options save tuples as lists which need to be converted - if isinstance(value, list) and len(value) == 2: - value = tuple(tuple(v) if len(v) == 2 else v for v in value) - if not ( ( option From 8e3271826895a0d4018c73906e3ae4e284ff3cfa Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 20 Oct 2023 16:53:44 -0700 Subject: [PATCH 19/29] Reduce test tolerance of sei_asymmetric_ec_reaction_limited --- tests/integration/test_models/standard_model_tests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index d0e38501c9..ccf12c6143 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -152,7 +152,10 @@ def test_serialisation(self, solver=None, t_eval=None): else: new_solver = new_model.default_solver - if isinstance(new_model, pybamm.lithium_ion.BaseModel): + if ( + isinstance(new_model, pybamm.lithium_ion.BaseModel) + and new_model.options["SEI"] != "ec reaction limited (asymmetric)" + ): new_solver.rtol = 1e-8 new_solver.atol = 1e-8 accuracy = 6 From 1e16b92de780ba42db4909bc87eb7b7dcf059949 
Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 20 Oct 2023 17:56:36 -0700 Subject: [PATCH 20/29] fix: change serialisation test accuracy Required for macOS python<3.11 --- tests/integration/test_models/standard_model_tests.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index ccf12c6143..d4074e15ef 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -152,15 +152,11 @@ def test_serialisation(self, solver=None, t_eval=None): else: new_solver = new_model.default_solver - if ( - isinstance(new_model, pybamm.lithium_ion.BaseModel) - and new_model.options["SEI"] != "ec reaction limited (asymmetric)" - ): + if isinstance(new_model, pybamm.lithium_ion.BaseModel): new_solver.rtol = 1e-8 new_solver.atol = 1e-8 - accuracy = 6 - else: - accuracy = 5 + + accuracy = 5 Crate = abs( self.parameter_values["Current function [A]"] From 62a46ef683504c117668abad0a611a12a79249fb Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 7 Nov 2023 18:31:37 -0800 Subject: [PATCH 21/29] Additional tests for codecov --- .../test_expression_tree/test_interpolant.py | 16 +++- tests/unit/test_meshes/test_meshes.py | 2 +- .../test_meshes/test_scikit_fem_submesh.py | 65 ++++++++++++++++ tests/unit/test_models/test_base_model.py | 77 +++++++++++++++---- .../test_base_battery_model.py | 24 ++++++ 5 files changed, 169 insertions(+), 15 deletions(-) diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 92e9ef86c2..5fa078cffc 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -326,7 +326,7 @@ def test_processing(self): self.assertEqual(interp, interp.new_copy()) - def test_to_json(self): + def test_to_from_json(self): x = np.linspace(0, 1, 10) y = pybamm.StateVector(slice(0, 2)) interp = pybamm.Interpolant(x, 2 * x, y) @@ -371,6 +371,20 @@ def test_to_json(self): # check correct re-creation self.assertEqual(pybamm.Interpolant._from_json(expected_json), interp) + # test to_from_json for 2d x & y + x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.01)) + xx, yy = np.meshgrid(x[0], x[1], indexing="ij") + z = np.sin(xx**2 + yy**2) + var1 = pybamm.StateVector(slice(0, 1)) + var2 = pybamm.StateVector(slice(1, 2)) + # linear + interp = pybamm.Interpolant(x, z, (var1, var2), interpolator="linear") + + interp2d_json = interp.to_json() + interp2d_json["children"] = (var1, var2) + + self.assertEqual(pybamm.Interpolant._from_json(interp2d_json), interp) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_meshes/test_meshes.py b/tests/unit/test_meshes/test_meshes.py index 000ec729a5..3066d14534 100644 --- a/tests/unit/test_meshes/test_meshes.py +++ b/tests/unit/test_meshes/test_meshes.py @@ -390,7 +390,7 @@ def test_1plus1D_tabs_right_left(self): # positive tab should be "left" self.assertEqual(mesh["current collector"].tabs["positive tab"], "left") - def test_to_from_json(self): + def test_to_json(self): r = pybamm.SpatialVariable( "r", domain=["negative particle"], coord_sys="spherical polar" ) diff --git a/tests/unit/test_meshes/test_scikit_fem_submesh.py b/tests/unit/test_meshes/test_scikit_fem_submesh.py index 88bde7941f..1e0839250e 100644 --- a/tests/unit/test_meshes/test_scikit_fem_submesh.py +++ 
b/tests/unit/test_meshes/test_scikit_fem_submesh.py @@ -218,6 +218,71 @@ def test_to_json(self): self.assertEqual(mesh_json, expected_json) + # test Uniform2DSubMesh serialisation + + submesh = mesh["current collector"].to_json() + + expected_submesh = { + "edges": { + "y": [ + 0.0, + 0.02666666666666667, + 0.05333333333333334, + 0.08, + 0.10666666666666667, + 0.13333333333333333, + 0.16, + 0.18666666666666668, + 0.21333333333333335, + 0.24000000000000002, + 0.26666666666666666, + 0.29333333333333333, + 0.32, + 0.3466666666666667, + 0.37333333333333335, + 0.4, + ], + "z": [ + 0.0, + 0.021739130434782608, + 0.043478260869565216, + 0.06521739130434782, + 0.08695652173913043, + 0.10869565217391304, + 0.13043478260869565, + 0.15217391304347827, + 0.17391304347826086, + 0.19565217391304346, + 0.21739130434782608, + 0.2391304347826087, + 0.2608695652173913, + 0.2826086956521739, + 0.30434782608695654, + 0.32608695652173914, + 0.34782608695652173, + 0.3695652173913043, + 0.3913043478260869, + 0.41304347826086957, + 0.43478260869565216, + 0.45652173913043476, + 0.4782608695652174, + 0.5, + ], + }, + "coord_sys": "cartesian", + "tabs": { + "negative": {"y_centre": 0.1, "z_centre": 0.5, "width": 0.1}, + "positive": {"y_centre": 0.3, "z_centre": 0.5, "width": 0.1}, + }, + } + + self.assertEqual(submesh, expected_submesh) + + new_submesh = pybamm.ScikitUniform2DSubMesh._from_json(submesh) + + for x, y in zip(mesh['current collector'].edges, new_submesh.edges): + np.testing.assert_array_equal(x, y) + class TestScikitFiniteElementChebyshev2DSubMesh(TestCase): def test_mesh_creation(self): diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 1274d1a7bf..438b7391a7 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -984,29 +984,80 @@ def test_timescale_lengthscale_get_set_not_implemented(self): model.length_scales = 1 def test_save_load_model(self): + # Set up model model = pybamm.BaseModel() - c = pybamm.Variable("c") - model.rhs = {c: -c} - model.initial_conditions = {c: 1} - model.variables["c"] = c - model.variables["2c"] = 2 * c + var_scalar = pybamm.Variable("var_scalar") + var_1D = pybamm.Variable("var_1D", domain="negative electrode") + var_2D = pybamm.Variable( + "var_2D", + domain="negative particle", + auxiliary_domains={"secondary": "negative electrode"}, + ) + var_concat_neg = pybamm.Variable("var_concat_neg", domain="negative electrode") + var_concat_sep = pybamm.Variable("var_concat_sep", domain="separator") + var_concat = pybamm.concatenation(var_concat_neg, var_concat_sep) + model.rhs = {var_scalar: -var_scalar, var_1D: -var_1D} + model.algebraic = {var_2D: -var_2D, var_concat: -var_concat} + model.initial_conditions = {var_scalar: 1, var_1D: 1, var_2D: 1, var_concat: 1} + model.variables = { + "var_scalar": var_scalar, + "var_1D": var_1D, + "var_2D": var_2D, + "var_concat_neg": var_concat_neg, + "var_concat_sep": var_concat_sep, + "var_concat": var_concat, + } - # setup and discretise - solution = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) + # Discretise + geometry = { + "negative electrode": {"x_n": {"min": 0, "max": 1}}, + "separator": {"x_s": {"min": 1, "max": 2}}, + "negative particle": {"r_n": {"min": 0, "max": 1}}, + } + submeshes = { + "negative electrode": pybamm.Uniform1DSubMesh, + "separator": pybamm.Uniform1DSubMesh, + "negative particle": pybamm.Uniform1DSubMesh, + } + var_pts = {"x_n": 10, "x_s": 10, "r_n": 5} + mesh = pybamm.Mesh(geometry, submeshes, var_pts) 
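# The geometry, mesh and spatial methods are set up in this test because
# serialisation only accepts discretised models; a short sketch of the error
# path when that step is skipped, using a throwaway model and an illustrative
# filename that never gets written.
import pybamm

undiscretised = pybamm.BaseModel()
var = pybamm.Variable("var")
undiscretised.rhs = {var: -var}
undiscretised.initial_conditions = {var: 1}

try:
    undiscretised.save_model(filename="never_written")
except NotImplementedError as error:
    # "PyBaMM can only serialise a discretised, ready-to-solve model."
    print(error)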
+ spatial_methods = { + "negative electrode": pybamm.FiniteVolume(), + "separator": pybamm.FiniteVolume(), + "negative particle": pybamm.FiniteVolume(), + } + disc = pybamm.Discretisation(mesh, spatial_methods) + model_disc = disc.process_model(model, inplace=False) + t = np.linspace(0, 1) + y = np.tile(3 * t, (1 + 30 + 50, 1)) - # save model - model.save_model(filename="test_base_model") + # Find baseline solution + solution = pybamm.Solution(t, y, model_disc, {}) - # raises warning if variables are saved - with self.assertWarns(pybamm.ModelWarning): - model.save_model(filename="test_base_model", variables=model.variables) + # save model + model_disc.save_model(filename="test_base_model") + # load without variables new_model = pybamm.load_model("test_base_model.json") - new_solution = pybamm.ScipySolver().solve(new_model, np.linspace(0, 1)) + new_solution = pybamm.Solution(t, y, new_model, {}) # model solutions match testing.assert_array_equal(solution.all_ys, new_solution.all_ys) + + # raises warning if variables are saved without mesh + with self.assertWarns(pybamm.ModelWarning): + model_disc.save_model( + filename="test_base_model", variables=model_disc.variables + ) + + model_disc.save_model( + filename="test_base_model", variables=model_disc.variables, mesh=mesh + ) + + # load with variables & mesh + new_model = pybamm.load_model("test_base_model.json") + os.remove("test_base_model.json") diff --git a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py index 79c6d8a720..91bcfc28cc 100644 --- a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py @@ -7,6 +7,7 @@ import unittest import io from contextlib import redirect_stdout +import os OPTIONS_DICT = { "surface form": "differential", @@ -449,6 +450,29 @@ def test_option_type(self): model = pybamm.BaseBatteryModel(options) self.assertEqual(model.options, options) + def test_save_load_model(self): + model = ( + pybamm.lithium_ion.SPM() + ) + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + + # save model + model.save_model(filename="test_base_battery_model", mesh=mesh, + variables=model.variables) + + # raises error if variables are saved without mesh + with self.assertRaises(ValueError): + model.save_model(filename="test_base_battery_model", + variables=model.variables) + + os.remove("test_base_battery_model.json") + class TestOptions(TestCase): def test_print_options(self): From 52112332b9bee283c777a54418494c3bd438e925 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 9 Nov 2023 16:59:51 -0800 Subject: [PATCH 22/29] More coverage updates to serialise and 1D meshes --- .../expression_tree/operations/serialise.py | 5 +--- .../test_one_dimensional_submesh.py | 4 ++++ .../test_serialisation/test_serialisation.py | 23 ++++++++++++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index b54f7b1078..cd2ff15c3d 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -238,10 +238,7 @@ def load_model( 
def _get_pybamm_class(self, snippet: dict): """Find a pybamm class to initialise from object path""" parts = snippet["py/object"].split(".") - try: - module = importlib.import_module(".".join(parts[:-1])) - except Exception as ex: - print(ex) + module = importlib.import_module(".".join(parts[:-1])) class_ = getattr(module, parts[-1]) diff --git a/tests/unit/test_meshes/test_one_dimensional_submesh.py b/tests/unit/test_meshes/test_one_dimensional_submesh.py index a7cafb5e25..514de4248b 100644 --- a/tests/unit/test_meshes/test_one_dimensional_submesh.py +++ b/tests/unit/test_meshes/test_one_dimensional_submesh.py @@ -44,6 +44,10 @@ def test_to_json(self): self.assertEqual(mesh_json, expected_json) + # check tabs work + new_mesh = pybamm.Uniform1DSubMesh._from_json(mesh_json) + self.assertEqual(mesh.tabs, new_mesh.tabs) + class TestUniform1DSubMesh(TestCase): def test_exceptions(self): diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 533baa718f..97299e669d 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -305,7 +305,7 @@ def test_get_pybamm_class(self): self.assertIsInstance(mesh_class, pybamm.Mesh) - with self.assertRaises(Exception): + with self.assertRaises(AttributeError): unrecognised_symbol = { "py/id": mock.ANY, "py/object": "pybamm.expression_tree.scalar.Scale", @@ -443,6 +443,27 @@ def test_reconstruct_pybamm_dict(self): self.assertEqual(new_dict, test_dict) + # test recreation if not passed a dict + test_list = ["left", "right"] + new_list = Serialise()._reconstruct_pybamm_dict(test_list) + + self.assertEqual(test_list, new_list) + + def test_convert_options(self): + options_dict = { + "current collector": "uniform", + "particle phases": ["2", "1"], + "open-circuit potential": [["single", "current sigmoid"], "single"], + } + + options_result = { + "current collector": "uniform", + "particle phases": ("2", "1"), + "open-circuit potential": (("single", "current sigmoid"), "single"), + } + + self.assertEqual(Serialise()._convert_options(options_dict), options_result) + def test_save_load_model(self): model = pybamm.lithium_ion.SPM(name="test_spm") geometry = model.default_geometry From afa187e6bb0e3f33e62ea40ff6f29a05e3c99779 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Thu, 16 Nov 2023 16:48:05 +0000 Subject: [PATCH 23/29] Update CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 483ca91a5e..259980804a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) +## Features + +- Serialisation added so models can be written to/read from JSON ([#3397](https://github.com/pybamm-team/PyBaMM/pull/3397)) + ## Bug fixes - Fixed bug that made identical Experiment steps with different end times crash ([#3516](https://github.com/pybamm-team/PyBaMM/pull/3516)) From b7453175ef944ff0063e17998f53ef0a846c32ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:48:46 +0000 Subject: [PATCH 24/29] style: pre-commit fixes --- pybamm/models/base_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index fa87509f03..ed26a9062a 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -10,7 +10,6 @@ import numpy as np import pybamm -from 
pybamm.expression_tree.operations.latexify import Latexify from pybamm.expression_tree.operations.serialise import Serialise from pybamm.util import have_optional_dependency From 3fec37ea72b1ca8d65345ba1242b936fcc8e83a2 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 24 Nov 2023 10:14:38 +0000 Subject: [PATCH 25/29] Add error message for experiment --- pybamm/simulation.py | 7 +++++++ tests/unit/test_serialisation/test_serialisation.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 1ad8f0c682..6da17a61b2 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -1199,6 +1199,13 @@ def save_model( mesh = self.mesh if (mesh or variables) else None variables = self.built_model.variables if variables else None + if self.operating_mode == "with experiment": + raise NotImplementedError( + """ + Serialising models coupled to experiments is not yet supported. + """ + ) + if self.built_model: Serialise().save_model( self.built_model, filename=filename, mesh=mesh, variables=variables diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index 97299e669d..e304091b22 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -528,6 +528,18 @@ def test_save_load_model(self): newest_solver = newest_model.default_solver newest_solver.solve(newest_model, [0, 3600]) + def test_save_experiment_model_error(self): + model = pybamm.lithium_ion.SPM() + experiment = pybamm.Experiment(["Discharge at 1C for 1 hour"]) + sim = pybamm.Simulation(model, experiment=experiment) + sim.solve() + + with self.assertRaisesRegex( + NotImplementedError, + "Serialising models coupled to experiments is not yet supported.", + ): + sim.save_model("spm_experiment", mesh=False, variables=False) + def test_serialised_model_plotting(self): # models without a mesh model = pybamm.BaseModel() From 92e7c9042cc35b9a5f591ec8b5073a08275da8cf Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Fri, 24 Nov 2023 11:57:15 +0000 Subject: [PATCH 26/29] Update notebook to suggest build() not solve() --- .../notebooks/models/saving_models.ipynb | 26 +++++-------------- pybamm/simulation.py | 2 +- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb index c3f72bc4e4..9ac76a611e 100644 --- a/docs/source/examples/notebooks/models/saving_models.ipynb +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -13,7 +13,7 @@ "source": [ "Models which are discretised (i.e. ready to solve/ previously solved, see [this notebook](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/spatial_methods/finite-volumes.ipynb) for more information on the pybamm.Discretisation class) can be serialised and saved to a JSON file, ready to be read in again either in PyBaMM, or a different modelling library. \n", "\n", - "In the example below, we build and solve a basic DFN model, and then save the model out to `sim_model_example.json`, which should have appear in the 'models' directory." + "In the example below, we build a basic DFN model, and then save the model out to `sim_model_example.json`, which should have appear in the 'models' directory." 
] }, { @@ -28,7 +28,8 @@ "# do the example\n", "dfn_model = pybamm.lithium_ion.DFN()\n", "dfn_sim = pybamm.Simulation(dfn_model)\n", - "dfn_sim.solve([0, 3600])\n", + "# discretise and build the model\n", + "dfn_sim.build()\n", "\n", "dfn_sim.save_model(\"sim_model_example\")" ] @@ -155,7 +156,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -207,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -228,22 +229,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1] Joel A. E. Andersson, Joris Gillis, Greg Horn, James B. Rawlings, and Moritz Diehl. CasADi – A software framework for nonlinear optimization and optimal control. Mathematical Programming Computation, 11(1):1–36, 2019. doi:10.1007/s12532-018-0139-4.\n", - "[2] Marc Doyle, Thomas F. Fuller, and John Newman. Modeling of galvanostatic charge and discharge of the lithium/polymer/insertion cell. Journal of the Electrochemical society, 140(6):1526–1533, 1993. doi:10.1149/1.2221597.\n", - "[3] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n", - "[4] Scott G. Marquis, Valentin Sulzer, Robert Timms, Colin P. Please, and S. Jon Chapman. An asymptotic derivation of a single particle model with electrolyte. Journal of The Electrochemical Society, 166(15):A3693–A3706, 2019. doi:10.1149/2.0341915jes.\n", - "[5] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. doi:10.5334/jors.309.\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "pybamm.print_citations()" ] diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 6da17a61b2..a25653b507 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -1214,7 +1214,7 @@ def save_model( raise NotImplementedError( """ PyBaMM can only serialise a discretised model. - Ensure the model has been built (e.g. run `solve()`) before saving. + Ensure the model has been built (e.g. run `build()`) before saving. 
""" ) From 04f4230ce6ddb64a88cddb31064b891bc4a4e729 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 28 Nov 2023 11:13:31 +0000 Subject: [PATCH 27/29] Add outputs to example notebook Fixes doctests error --- .../notebooks/models/saving_models.ipynb | 146 ++++++++++++++++-- 1 file changed, 130 insertions(+), 16 deletions(-) diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb index 9ac76a611e..91a6f2ae5c 100644 --- a/docs/source/examples/notebooks/models/saving_models.ipynb +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -18,9 +18,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip install pybamm -q # install PyBaMM if it is not installed\n", "import pybamm\n", @@ -43,9 +51,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Recreate the pybamm model from the JSON file\n", "new_dfn_model = pybamm.load_model(\"sim_model_example.json\")\n", @@ -65,13 +84,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "tags": [ "raises-exception" ] }, - "outputs": [], + "outputs": [ + { + "ename": "AttributeError", + "evalue": "No variables to plot", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/pipliggins/Documents/repos/pybamm-local/docs/source/examples/notebooks/models/saving_models.ipynb Cell 7\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m plot_sim\u001b[39m.\u001b[39msolve([\u001b[39m0\u001b[39m, \u001b[39m3600\u001b[39m])\n\u001b[1;32m 6\u001b[0m sims\u001b[39m.\u001b[39mappend(plot_sim)\n\u001b[0;32m----> 8\u001b[0m pybamm\u001b[39m.\u001b[39;49mdynamic_plot(sims, time_unit\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mseconds\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/repos/pybamm-local/pybamm/plotting/dynamic_plot.py:20\u001b[0m, in \u001b[0;36mdynamic_plot\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[39mCreates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[39marguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39m The 'QuickPlot' object that was created\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 19\u001b[0m kwargs_for_class \u001b[39m=\u001b[39m {k: v \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m kwargs\u001b[39m.\u001b[39mitems() \u001b[39mif\u001b[39;00m k \u001b[39m!=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[0;32m---> 20\u001b[0m plot \u001b[39m=\u001b[39m pybamm\u001b[39m.\u001b[39;49mQuickPlot(\u001b[39m*\u001b[39;49margs, 
\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs_for_class)\n\u001b[1;32m 21\u001b[0m plot\u001b[39m.\u001b[39mdynamic_plot(kwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m))\n\u001b[1;32m 22\u001b[0m \u001b[39mreturn\u001b[39;00m plot\n", + "File \u001b[0;32m~/Documents/repos/pybamm-local/pybamm/plotting/quick_plot.py:146\u001b[0m, in \u001b[0;36mQuickPlot.__init__\u001b[0;34m(self, solutions, output_variables, labels, colors, linestyles, shading, figsize, n_rows, time_unit, spatial_unit, variable_limits)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[39m# check variables have been provided after any serialisation\u001b[39;00m\n\u001b[1;32m 145\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(\u001b[39mlen\u001b[39m(m\u001b[39m.\u001b[39mvariables) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m \u001b[39mfor\u001b[39;00m m \u001b[39min\u001b[39;00m models):\n\u001b[0;32m--> 146\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mNo variables to plot\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 148\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows \u001b[39m=\u001b[39m n_rows \u001b[39mor\u001b[39;00m \u001b[39mint\u001b[39m(\n\u001b[1;32m 149\u001b[0m \u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m np\u001b[39m.\u001b[39msqrt(\u001b[39mlen\u001b[39m(output_variables))\n\u001b[1;32m 150\u001b[0m )\n\u001b[1;32m 151\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_cols \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(np\u001b[39m.\u001b[39mceil(\u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows))\n", + "\u001b[0;31mAttributeError\u001b[0m: No variables to plot" + ] + } + ], "source": [ "dfn_models = [dfn_model, new_dfn_model]\n", "sims = []\n", @@ -94,9 +127,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "81d8329fab424264bd56c65d53d34f63", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='t', max=3600.0, step=36.0), Output()), _dom_classes=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# using the first simulation, save a new file which includes a list of the model variables\n", "dfn_sim.save_model(\"sim_model_variables\", variables=True)\n", @@ -130,9 +188,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# create the model\n", "spm_model = pybamm.lithium_ion.SPM()\n", @@ -156,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -173,9 +242,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ce5addf4f59c447e97d2fbee633cb6e0", + "version_major": 2, + 
"version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# read back in\n", "new_spm_model = pybamm.load_model(\"example_model.json\")\n", @@ -208,7 +302,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -229,12 +323,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1] Joel A. E. Andersson, Joris Gillis, Greg Horn, James B. Rawlings, and Moritz Diehl. CasADi – A software framework for nonlinear optimization and optimal control. Mathematical Programming Computation, 11(1):1–36, 2019. doi:10.1007/s12532-018-0139-4.\n", + "[2] Marc Doyle, Thomas F. Fuller, and John Newman. Modeling of galvanostatic charge and discharge of the lithium/polymer/insertion cell. Journal of the Electrochemical society, 140(6):1526–1533, 1993. doi:10.1149/1.2221597.\n", + "[3] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n", + "[4] Scott G. Marquis, Valentin Sulzer, Robert Timms, Colin P. Please, and S. Jon Chapman. An asymptotic derivation of a single particle model with electrolyte. Journal of The Electrochemical Society, 166(15):A3693–A3706, 2019. doi:10.1149/2.0341915jes.\n", + "[5] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. 
doi:10.5334/jors.309.\n", + "\n" + ] + } + ], "source": [ "pybamm.print_citations()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From ca63509060a895aabe812a7a1d2eebc08d8e2633 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 11:16:07 +0000 Subject: [PATCH 28/29] style: pre-commit fixes --- tests/unit/test_serialisation/test_serialisation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index e304091b22..6c43eaa9d7 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -272,7 +272,7 @@ def test_deconstruct_pybamm_dicts(self): ser_dict = { "rod": { "symbol_x": { - "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", # noqa: E501 + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", "py/id": mock.ANY, "name": "x", "id": mock.ANY, @@ -341,7 +341,7 @@ def test_reconstruct_expression_tree(self): }, "children": [ { - "py/object": "pybamm.expression_tree.binary_operators.Multiplication", # noqa: E501 + "py/object": "pybamm.expression_tree.binary_operators.Multiplication", "py/id": 139691619709232, "name": "*", "id": 6094209803352873499, @@ -361,7 +361,7 @@ def test_reconstruct_expression_tree(self): "children": [], }, { - "py/object": "pybamm.expression_tree.state_vector.StateVector", # noqa: E501 + "py/object": "pybamm.expression_tree.state_vector.StateVector", "py/id": 139691619589760, "name": "y[0:1]", "id": 5063056989669636089, @@ -423,7 +423,7 @@ def test_reconstruct_pybamm_dict(self): ser_dict = { "rod": { "symbol_x": { - "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", # noqa: E501 + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", "py/id": mock.ANY, "name": "x", "id": mock.ANY, From df35b91c894a42c1618b6a50375e4e6bc27b8d60 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 28 Nov 2023 11:23:56 +0000 Subject: [PATCH 29/29] Fix ruff errors --- pybamm/expression_tree/operations/serialise.py | 10 ++++++---- pybamm/simulation.py | 7 +++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py index cd2ff15c3d..c7768217a3 100644 --- a/pybamm/expression_tree/operations/serialise.py +++ b/pybamm/expression_tree/operations/serialise.py @@ -7,6 +7,8 @@ import numpy as np import re +from typing import Optional + class Serialise: """ @@ -78,9 +80,9 @@ class _EmptyDict(dict): def save_model( self, model: pybamm.BaseModel, - mesh: pybamm.Mesh = None, - variables: pybamm.FuzzyDict = None, - filename: str = None, + mesh: Optional[pybamm.Mesh] = None, + variables: Optional[pybamm.FuzzyDict] = None, + filename: Optional[str] = None, ): """Saves a discretised model to a JSON file. @@ -142,7 +144,7 @@ def save_model( json.dump(model_json, f) def load_model( - self, filename: str, battery_model: pybamm.BaseModel = None + self, filename: str, battery_model: Optional[pybamm.BaseModel] = None ) -> pybamm.BaseModel: """ Loads a discretised, ready to solve model into PyBaMM. 
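A side note on the `Optional[...]` annotations introduced in the serialise.py hunk above: the commit is titled "Fix ruff errors", and the change replaces the implicit-optional pattern (`filename: str = None`), which ruff's implicit-optional check rejects because `None` is not a member of the annotated type. A minimal before/after sketch follows; the function names are hypothetical and not PyBaMM API.

    from typing import Optional


    # Implicit optional: the default is None, but the annotation says `str`.
    # Linters that require explicit optionals flag this form.
    def save_old(filename: str = None):
        return filename or "model.json"


    # Explicit optional: the annotation now admits None, matching the default.
    def save_new(filename: Optional[str] = None) -> str:
        return filename or "model.json"
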
diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 4fe9c32924..83a386fe98 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -10,6 +10,7 @@ from functools import lru_cache from datetime import timedelta from pybamm.util import have_optional_dependency +from typing import Optional from pybamm.expression_tree.operations.serialise import Serialise @@ -795,7 +796,9 @@ def solve( # Hacky patch to allow correct processing of end_time and next_starting time # For efficiency purposes, op_conds treats identical steps as the same object # regardless of the initial time. Should be refactored as part of #3176 - op_conds_unproc = self.experiment.operating_conditions_steps_unprocessed[idx] + op_conds_unproc = ( + self.experiment.operating_conditions_steps_unprocessed[idx] + ) start_time = current_solution.t[-1] @@ -1192,7 +1195,7 @@ def save(self, filename): def save_model( self, - filename: str = None, + filename: Optional[str] = None, mesh: bool = False, variables: bool = False, ):
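
Taken together, the changes in this series support the workflow sketched below: discretise a model, serialise it to JSON, and load it back in for solving. This is a sketch based on the API shown in the diffs and notebook above (`Simulation.build()`, `Simulation.save_model()`, `pybamm.load_model()`); the filename is illustrative, and, as in the notebook, a `.json` extension is assumed to be appended on save.

    import pybamm

    # Discretise the model without solving it; save_model() needs a built
    # (discretised) model, hence build() rather than solve().
    model = pybamm.lithium_ion.SPM()
    sim = pybamm.Simulation(model)
    sim.build()

    # Serialise to JSON. Saving the mesh and variables as well makes the
    # reloaded model plottable, at the cost of a larger file.
    sim.save_model("spm_roundtrip", mesh=True, variables=True)

    # Read the discretised model back in and solve it directly.
    new_model = pybamm.load_model("spm_roundtrip.json")
    solution = new_model.default_solver.solve(new_model, [0, 3600])
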
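The failure modes exercised in the tests above can be summarised as follows. This is a sketch of the expected behaviour per those tests and the messages added to `Simulation.save_model()`; exact wording and exception types may differ in edge cases.

    import pybamm

    # 1. Saving before the model is built: per the message updated above,
    #    only a discretised model can be serialised, so this is expected to
    #    raise NotImplementedError asking for build() to be run first.
    sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
    try:
        sim.save_model("spm_unbuilt")
    except NotImplementedError as error:
        print(error)

    # 2. Models coupled to an Experiment are not yet supported.
    experiment = pybamm.Experiment(["Discharge at 1C for 1 hour"])
    exp_sim = pybamm.Simulation(pybamm.lithium_ion.SPM(), experiment=experiment)
    exp_sim.solve()
    try:
        exp_sim.save_model("spm_experiment", mesh=False, variables=False)
    except NotImplementedError as error:
        print(error)

    # 3. Saving variables without the mesh needed to process them: the
    #    battery-model test above expects a ValueError from
    #    BaseModel.save_model (the plain BaseModel test expects a
    #    pybamm.ModelWarning instead).
    sim.build()
    try:
        sim.built_model.save_model(
            filename="spm_vars_no_mesh", variables=sim.built_model.variables
        )
    except ValueError as error:
        print(error)
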