Skip to content

Commit

Permalink
#2382 SPMe example working
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Oct 20, 2022
1 parent d4ecb7e commit c9bda6c
Show file tree
Hide file tree
Showing 13 changed files with 159 additions and 186 deletions.
99 changes: 3 additions & 96 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
#
import pybamm
import numpy as np
from collections import defaultdict, OrderedDict
from scipy.sparse import block_diag, csc_matrix, csr_matrix
from scipy.sparse.linalg import inv
from collections import defaultdict


def has_bc_of_form(symbol, side, bcs, form):
Expand Down Expand Up @@ -193,10 +191,8 @@ def process_model(
)
self.set_internal_boundary_conditions(model)

kwargs = {}

# Keep a record of y_slices in the model
y_slices = self.y_slices_explicit
y_slices = self.y_slices
# Keep a record of the bounds in the model
bounds = self.bounds

Expand Down Expand Up @@ -257,10 +253,6 @@ def process_model(
# create a copy of the original model
model_disc = model.new_copy(equations=discretised_equations)

# Create mass matrix
pybamm.logger.verbose("Create mass matrix for {}".format(model.name))
mass_matrix, mass_matrix_inv = self.create_mass_matrix(model_disc)

# Check that resulting model makes sense
if check_model:
pybamm.logger.verbose("Performing model checks for {}".format(model.name))
Expand All @@ -281,7 +273,6 @@ def set_variable_slices(self, variables):
"""
# Set up y_slices and bounds
y_slices = defaultdict(list)
y_slices_explicit = defaultdict(list)
start = 0
end = 0
lower_bounds = []
Expand All @@ -293,7 +284,7 @@ def set_variable_slices(self, variables):
start_ = start
spatial_method = self.spatial_methods[variable.domain[0]]
children = variable.children
meshes = OrderedDict()
meshes = {}
for child in children:
meshes[child] = [spatial_method.mesh[dom] for dom in child.domain]
sec_points = spatial_method._get_auxiliary_domain_repeats(
Expand All @@ -305,15 +296,13 @@ def set_variable_slices(self, variables):
end += domain_mesh.npts_for_broadcast_to_nodes
# Add to slices
y_slices[child].append(slice(start_, end))
y_slices_explicit[child].append(slice(start_, end))
# Increment start_
start_ = end
else:
end += self._get_variable_size(variable)

# Add to slices
y_slices[variable].append(slice(start, end))
y_slices_explicit[variable].append(slice(start, end))
# Add to bounds
lower_bounds.extend([variable.bounds[0]] * (end - start))
upper_bounds.extend([variable.bounds[1]] * (end - start))
Expand All @@ -322,8 +311,6 @@ def set_variable_slices(self, variables):

# Convert y_slices back to normal dictionary
self.y_slices = dict(y_slices)
# Also keep a record of what the y_slices are, to be stored in the model
self.y_slices_explicit = dict(y_slices_explicit)

# Also keep a record of bounds
self.bounds = (np.array(lower_bounds), np.array(upper_bounds))
Expand Down Expand Up @@ -572,86 +559,6 @@ def check_tab_conditions(self, symbol, bcs):

return bcs

def create_mass_matrix(self, model):
"""Creates mass matrix of the discretised model.
Note that the model is assumed to be of the form M*y_dot = f(t,y), where
M is the (possibly singular) mass matrix.
Parameters
----------
model : :class:`pybamm.BaseModel`
Discretised model. Must have attributes rhs, initial_conditions and
boundary_conditions (all dicts of {variable: equation})
Returns
-------
:class:`pybamm.Matrix`
The mass matrix
:class:`pybamm.Matrix`
The inverse of the ode part of the mass matrix (required by solvers
which only accept the ODEs in explicit form)
"""
# Create list of mass matrices for each equation to be put into block
# diagonal mass matrix for the model
mass_list = []
mass_inv_list = []

# get a list of model rhs variables that are sorted according to
# where they are in the state vector
model_variables = model.rhs.keys()
model_slices = []
for v in model_variables:
model_slices.append(self.y_slices[v][0])
sorted_model_variables = [
v for _, v in sorted(zip(model_slices, model_variables))
]

# Process mass matrices for the differential equations
for var in sorted_model_variables:
if var.domain == []:
# If variable domain empty then mass matrix is just 1
mass_list.append(1.0)
mass_inv_list.append(1.0)
else:
mass = (
self.spatial_methods[var.domain[0]]
.mass_matrix(var, self.bcs)
.entries
)
mass_list.append(mass)
if isinstance(
self.spatial_methods[var.domain[0]],
(pybamm.ZeroDimensionalSpatialMethod, pybamm.FiniteVolume),
):
# for 0D methods the mass matrix is just a scalar 1 and for
# finite volumes the mass matrix is identity, so no need to
# compute the inverse
mass_inv_list.append(mass)
else:
# inverse is more efficient in csc format
mass_inv = inv(csc_matrix(mass))
mass_inv_list.append(mass_inv)

# Create lumped mass matrix (of zeros) of the correct shape for the
# discretised algebraic equations
if model.algebraic.keys():
mass_algebraic_size = model.concatenated_algebraic.shape[0]
mass_algebraic = csr_matrix((mass_algebraic_size, mass_algebraic_size))
mass_list.append(mass_algebraic)

# Create block diagonal (sparse) mass matrix (if model is not empty)
# and inverse (if model has odes)
if len(model.rhs) + len(model.algebraic) > 0:
mass_matrix = pybamm.Matrix(block_diag(mass_list, format="csr"))
if len(model.rhs) > 0:
mass_matrix_inv = pybamm.Matrix(block_diag(mass_inv_list, format="csr"))
else:
mass_matrix_inv = None
else:
mass_matrix, mass_matrix_inv = None, None

return mass_matrix, mass_matrix_inv

def process_dict(self, var_eqn_dict):
"""Discretise a dictionary of {variable: equation}, broadcasting if necessary
(can be model.rhs, model.algebraic, model.initial_conditions or
Expand Down
32 changes: 19 additions & 13 deletions pybamm/models/base_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,16 @@ def variables(self):
def variable_names(self):
return list(self._variables.keys())

@property
@functools.cached_property
def variables_and_events(self):
"""
Returns variables and events in a single dictionary
"""
try:
return self._variables_and_events
except AttributeError:
self._variables_and_events = self.variables.copy()
self._variables_and_events.update(
{f"Event: {event.name}": event.expression for event in self.events}
)
return self._variables_and_events
variables_and_events = self.variables.copy()
variables_and_events.update(
{f"Event: {event.name}": event.expression for event in self.events}
)
return variables_and_events

@property
def events(self):
Expand Down Expand Up @@ -436,13 +433,14 @@ def __init__(
boundary_conditions=pybamm.ReadOnlyDict(boundary_conditions),
# Variables is initially empty, but will be filled in when variables are
# called
variables=_OnTheFlyUpdatedDict(self.variables_update_function),
variables=_OnTheFlyUpdatedDict(
unprocessed_variables, self.variables_update_function
),
events=events,
external_variables=external_variables,
timescale=timescale,
length_scales=length_scales,
)
self._unprocessed_variables = unprocessed_variables

@_BaseEquations.rhs.setter
def rhs(self, value):
Expand All @@ -454,11 +452,19 @@ class _OnTheFlyUpdatedDict(dict):
A dictionary that updates itself when a key is called.
"""

def __init__(self, variables_update_function):
def __init__(self, unprocessed_variables, variables_update_function):
super().__init__({})
self.unprocessed_variables = unprocessed_variables
self.variables_update_function = variables_update_function

def __getitem__(self, key):
if key not in self:
self.update(self.variables_update_function(key))
self.update(
{key: self.variables_update_function(self.unprocessed_variables[key])}
)
return super().__getitem__(key)

def copy(self):
return self.__class__(
self.unprocessed_variables, self.variables_update_function
)
41 changes: 27 additions & 14 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#
# Base model class
#
import numbers
import warnings
from collections import OrderedDict

import copy
import casadi
import numpy as np
Expand Down Expand Up @@ -120,15 +116,12 @@ def external_variables(self):
def external_variables(self, variables):
self._equations.external_variables = variables

def variable_names(self):
return list(self._equations._variables.keys())

@property
def variables_and_events(self):
"""
Returns variables and events in a single dictionary
"""
return self._equations._variables_and_events
return self._equations.variables_and_events

@property
def events(self):
Expand Down Expand Up @@ -221,19 +214,39 @@ def print_parameter_info(self):

@property
def bounds(self):
return self._equations.bounds
return self._equations._bounds

@property
def len_rhs(self):
return self._equations.len_rhs
return self._equations._len_rhs

@property
def len_algebraic(self):
return self._equations.len_algebraic
def len_alg(self):
return self._equations._len_alg

@property
def len_rhs_and_alg(self):
return self._equations.len_rhs_and_alg
return self._equations._len_rhs_and_alg

@property
def mass_matrix(self):
return self._equations._mass_matrix

@property
def mass_matrix_inv(self):
return self._equations._mass_matrix_inv

@property
def concatenated_rhs(self):
return self._equations._concatenated_rhs

@property
def concatenated_algebraic(self):
return self._equations._concatenated_algebraic

@property
def concatenated_initial_conditions(self):
return self._equations._concatenated_initial_conditions

def new_copy(self, equations=None):
"""
Expand Down Expand Up @@ -488,7 +501,7 @@ def export_casadi_objects(self, variable_names, input_parameter_order=None):
jac_algebraic = casadi.jacobian(algebraic, y_casadi)

# For specified variables, convert to casadi
variables = OrderedDict()
variables = {}
for name in variable_names:
var = self.variables[name]
variables[name] = var.to_casadi(t_casadi, y_casadi, inputs=ext_and_in)
Expand Down
Loading

0 comments on commit c9bda6c

Please sign in to comment.