Skip to content

Commit

Permalink
Merge pull request #3397 from pipliggins/serialisation
Browse files Browse the repository at this point in the history
Serialisation of models
  • Loading branch information
martinjrobins authored Nov 28, 2023
2 parents c86f8fe + df35b91 commit 25b1e75
Show file tree
Hide file tree
Showing 56 changed files with 3,270 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features

- Serialisation added so models can be written to/read from JSON ([#3397](https://github.com/pybamm-team/PyBaMM/pull/3397))

## Bug fixes

- Fixed a bug where simulations using the CasADi-based solvers would fail randomly with the half-cell model ([#3494](https://github.com/pybamm-team/PyBaMM/pull/3494))
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/expression_tree/operations/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ Classes and functions that operate on the expression tree
evaluate
jacobian
convert_to_casadi
serialise
unpack_symbol
5 changes: 5 additions & 0 deletions docs/source/api/expression_tree/operations/serialise.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Serialise
=========

.. autoclass:: pybamm.expression_tree.operations.serialise.Serialise
:members:
1 change: 1 addition & 0 deletions docs/source/examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ The notebooks are organised into subfolders, and can be viewed in the galleries
notebooks/models/MSMR.ipynb
notebooks/models/pouch-cell-model.ipynb
notebooks/models/rate-capability.ipynb
notebooks/models/saving_models.ipynb
notebooks/models/SEI-on-cracks.ipynb
notebooks/models/simulating-ORegan-2022-parameter-set.ipynb
notebooks/models/SPM.ipynb
Expand Down
376 changes: 376 additions & 0 deletions docs/source/examples/notebooks/models/saving_models.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@
UserSupplied2DSubMesh,
)

#
# Serialisation
#
from .models.base_model import load_model

#
# Spatial Methods
#
Expand Down
48 changes: 48 additions & 0 deletions pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ def __init__(
name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains
)

@classmethod
def _from_json(cls, snippet: dict):
    """
    Reconstruct an Array from a JSON snippet produced by :meth:`to_json`.

    Parameters
    ----------
    snippet : dict
        Expected keys: "name", "domains", and "entries". "entries" is
        either a nested list (dense data) or a dict with keys "data",
        "row_indices", "column_pointers" and "shape" (sparse CSR data).

    Returns
    -------
    An instance of ``cls`` initialised with the deserialised entries.
    """
    # Create a bare instance first so __init__ can be called explicitly
    # below with the reconstructed matrix.
    instance = cls.__new__(cls)

    if isinstance(snippet["entries"], dict):
        # Sparse case: rebuild the CSR matrix from its three arrays.
        # NOTE(review): despite the key names, "row_indices" carries the
        # CSR `indices` array and "column_pointers" the `indptr` array
        # (mirroring what to_json stores), so the round-trip is consistent.
        matrix = csr_matrix(
            (
                snippet["entries"]["data"],
                snippet["entries"]["row_indices"],
                snippet["entries"]["column_pointers"],
            ),
            shape=snippet["entries"]["shape"],
        )
    else:
        # Dense case: entries were stored as a (nested) list.
        matrix = snippet["entries"]

    instance.__init__(
        matrix,
        name=snippet["name"],
        domains=snippet["domains"],
    )

    return instance

@property
def entries(self):
    # Read-only access to the underlying entries
    # (a dense numpy array or a scipy csr_matrix — see to_json).
    return self._entries
Expand Down Expand Up @@ -129,6 +153,30 @@ def to_equation(self):
entries_list = self.entries.tolist()
return sympy.Array(entries_list)

def to_json(self):
    """
    Method to serialise an Array object into JSON.

    Dense entries (``np.ndarray``) are converted to nested lists; sparse
    entries (``csr_matrix``) are decomposed into their data, index and
    index-pointer arrays so they can be reconstructed losslessly by
    :meth:`_from_json`.

    Returns
    -------
    dict
        JSON-compatible dictionary with keys "name", "id", "domains"
        and "entries".

    Raises
    ------
    NotImplementedError
        If ``entries`` is neither a numpy array nor a CSR matrix.
    """

    if isinstance(self.entries, np.ndarray):
        matrix = self.entries.tolist()
    elif isinstance(self.entries, csr_matrix):
        # NOTE(review): key names are kept for format compatibility with
        # _from_json, although "row_indices" holds the CSR `indices` array
        # and "column_pointers" the `indptr` array.
        matrix = {
            "shape": self.entries.shape,
            "data": self.entries.data.tolist(),
            "row_indices": self.entries.indices.tolist(),
            "column_pointers": self.entries.indptr.tolist(),
        }
    else:
        # Previously `matrix` was left unbound here, producing a confusing
        # UnboundLocalError; fail with an explicit message instead.
        raise NotImplementedError(
            f"Cannot serialise entries of type {type(self.entries)}"
        )

    json_dict = {
        "name": self.name,
        "id": self.id,
        "domains": self.domains,
        "entries": matrix,
    }

    return json_dict


def linspace(start, stop, num=50, **kwargs):
"""
Expand Down
26 changes: 26 additions & 0 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ def __init__(self, name, left, right):
self.left = self.children[0]
self.right = self.children[1]

@classmethod
def _from_json(cls, snippet: dict):
    """Use to instantiate when deserialising; discretisation has
    already occurred so pre-processing of binaries is not necessary."""

    # Bare instance: skip cls.__init__ and call the parent class's
    # __init__ directly, since the pre-processing done at construction
    # time is not needed after discretisation.
    instance = cls.__new__(cls)

    super(BinaryOperator, instance).__init__(
        snippet["name"],
        children=[snippet["children"][0], snippet["children"][1]],
        domains=snippet["domains"],
    )
    # Restore the left/right aliases normally set in __init__.
    instance.left = instance.children[0]
    instance.right = instance.children[1]

    return instance

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
# Possibly add brackets for clarity
Expand Down Expand Up @@ -156,6 +173,15 @@ def to_equation(self):
eq2 = child2.to_equation()
return self._sympy_operator(eq1, eq2)

def to_json(self):
    """
    Method to serialise a BinaryOperator object into JSON.

    Only the node's own metadata is included; the children are not
    part of this dictionary.
    """

    return {
        "name": self.name,
        "id": self.id,
        "domains": self.domains,
    }


class Power(BinaryOperator):
"""
Expand Down
11 changes: 11 additions & 0 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def _diff(self, variable):
# Differentiate the child and broadcast the result in the same way
return self._unary_new_copy(self.child.diff(variable))

def to_json(self):
    # Deliberately unsupported: serialisation only handles discretised
    # models, and the error message tells the user so.
    raise NotImplementedError(
        "pybamm.Broadcast: Serialisation is only implemented for discretised models"
    )

@classmethod
def _from_json(cls, snippet):
    # Counterpart to to_json above: broadcasts cannot be deserialised;
    # users must serialise a discretised model instead.
    raise NotImplementedError(
        "pybamm.Broadcast: Please use a discretised model when reading in from JSON"
    )


class PrimaryBroadcast(Broadcast):
"""
Expand Down
74 changes: 74 additions & 0 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None):

super().__init__(name, children, domains=domains)

@classmethod
def _from_json(cls, *children, name, domains, concat_fun=None):
    """Creates a new Concatenation instance from a json object"""
    # Bare instance: skip cls.__init__ (and the pre-processing it does)
    # and set the concatenation function directly.
    instance = cls.__new__(cls)

    instance.concatenation_function = concat_fun

    # Call the parent class's __init__ on the bare instance to set up
    # name/children/domains.
    super(Concatenation, instance).__init__(name, children, domains=domains)

    return instance

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
out = self.name + "("
Expand Down Expand Up @@ -183,6 +194,18 @@ def __init__(self, *children):
concat_fun=np.concatenate
)

@classmethod
def _from_json(cls, snippet: dict):
    """See :meth:`pybamm.Concatenation._from_json()`."""
    # Delegate to the generic Concatenation deserialiser, fixing the
    # name and the concatenation function for the numpy-backed variant.
    return super()._from_json(
        *snippet["children"],
        name="numpy_concatenation",
        domains=snippet["domains"],
        concat_fun=np.concatenate,
    )

def _concatenation_jac(self, children_jacs):
"""See :meth:`pybamm.Concatenation.concatenation_jac()`."""
children = self.children
Expand Down Expand Up @@ -251,6 +274,31 @@ def __init__(self, children, full_mesh, copy_this=None):
self._children_slices = copy.copy(copy_this._children_slices)
self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts

@classmethod
def _from_json(cls, snippet: dict):
    """See :meth:`pybamm.Concatenation._from_json()`."""
    # Build the base concatenation first (children, name, domains).
    instance = super()._from_json(
        *snippet["children"],
        name="domain_concatenation",
        domains=snippet["domains"]
    )

    def repack_defaultDict(slices):
        # Inverse of to_json's unpacking: restore a defaultdict(list) and
        # turn each {"start", "stop", "step"} dict back into a slice object.
        slices = defaultdict(list, slices)
        for domain, sls in slices.items():
            sls = [slice(s["start"], s["stop"], s["step"]) for s in sls]
            slices[domain] = sls
        return slices

    # Restore the slice/size bookkeeping normally set up in __init__.
    instance._size = snippet["size"]
    instance._slices = repack_defaultDict(snippet["slices"])
    instance._children_slices = [
        repack_defaultDict(s) for s in snippet["children_slices"]
    ]
    instance.secondary_dimensions_npts = snippet["secondary_dimensions_npts"]

    return instance

def _get_auxiliary_domain_repeats(self, auxiliary_domains):
"""Helper method to read the 'auxiliary_domain' meshes."""
mesh_pts = 1
Expand Down Expand Up @@ -316,6 +364,32 @@ def _concatenation_new_copy(self, children):
)
return new_symbol

def to_json(self):
    """
    Method to serialise a DomainConcatenation object into JSON.

    ``slice`` objects are not JSON-serialisable, so each one is expanded
    into a ``{"start", "stop", "step"}`` dictionary; ``_from_json``
    performs the inverse conversion.
    """

    def _flatten(slice_map):
        # Plain-dict copy of a (default)dict of lists of slices, with
        # every slice expanded into its three components.
        return {
            domain: [
                {"start": s.start, "stop": s.stop, "step": s.step}
                for s in slice_list
            ]
            for domain, slice_list in slice_map.items()
        }

    return {
        "name": self.name,
        "id": self.id,
        "domains": self.domains,
        "slices": _flatten(self._slices),
        "size": self._size,
        "children_slices": [
            _flatten(child_slice) for child_slice in self._children_slices
        ],
        "secondary_dimensions_npts": self.secondary_dimensions_npts,
    }


class SparseStack(Concatenation):
"""
Expand Down
Loading

0 comments on commit 25b1e75

Please sign in to comment.