Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More simplifications #2443

Merged
merged 13 commits into from
Nov 10, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

## Optimizations

- Added more rules for simplifying expressions, especially around Concatenations. Also, meshes constructed from multiple domains are now cached ([#2443](https://github.com/pybamm-team/PyBaMM/pull/2443))
- Added more rules for simplifying expressions. Constants in binary operators are now moved to the left by default (e.g. `x*2` returns `2*x`) ([#2424](https://github.com/pybamm-team/PyBaMM/pull/2424))

## Breaking changes
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/compare_comsol/compare_comsol_DFN.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_interp_fun(variable_name, domain):
comsol_x = comsol_variables["x"]

# Make sure to use dimensional space
pybamm_x = mesh.combine_submeshes(*domain).nodes * L_x
pybamm_x = mesh[domain].nodes * L_x
variable = interp.interp1d(comsol_x, variable, axis=0)(pybamm_x)

fun = pybamm.Interpolant(
Expand All @@ -88,7 +88,7 @@ def get_interp_fun(variable_name, domain):
)

fun.domains = {"primary": domain}
fun.mesh = mesh.combine_submeshes(*domain)
fun.mesh = mesh[domain]
fun.secondary_mesh = None
return fun

Expand Down
23 changes: 11 additions & 12 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,10 @@ def process_dict(self, var_eqn_dict):
for eqn_key, eqn in var_eqn_dict.items():
# Broadcast if the equation evaluates to a number (e.g. Scalar)
if np.prod(eqn.shape_for_testing) == 1 and not isinstance(eqn_key, str):
eqn = pybamm.FullBroadcast(eqn, broadcast_domains=eqn_key.domains)
if eqn_key.domain == []:
eqn = eqn * pybamm.Vector([1])
else:
eqn = pybamm.FullBroadcast(eqn, broadcast_domains=eqn_key.domains)

pybamm.logger.debug("Discretise {!r}".format(eqn_key))

Expand Down Expand Up @@ -784,14 +787,14 @@ def process_symbol(self, symbol):

# Assign mesh as an attribute to the processed variable
if symbol.domain != []:
discretised_symbol.mesh = self.mesh.combine_submeshes(*symbol.domain)
discretised_symbol.mesh = self.mesh[symbol.domain]
else:
discretised_symbol.mesh = None
# Assign secondary mesh
if symbol.domains["secondary"] != []:
discretised_symbol.secondary_mesh = self.mesh.combine_submeshes(
*symbol.domains["secondary"]
)
discretised_symbol.secondary_mesh = self.mesh[
symbol.domains["secondary"]
]
else:
discretised_symbol.secondary_mesh = None
return discretised_symbol
Expand Down Expand Up @@ -897,13 +900,9 @@ def _process_symbol(self, symbol):
elif isinstance(symbol, pybamm.Broadcast):
# Broadcast new_child to the domain specified by symbol.domain
# Different discretisations may broadcast differently
if symbol.domain == []:
out = disc_child * pybamm.Vector([1])
else:
out = spatial_method.broadcast(
disc_child, symbol.domains, symbol.broadcast_type
)
return out
return spatial_method.broadcast(
disc_child, symbol.domains, symbol.broadcast_type
)

elif isinstance(symbol, pybamm.DeltaFunction):
return spatial_method.delta_function(symbol, disc_child)
Expand Down
42 changes: 14 additions & 28 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,36 +144,22 @@ def x_average(symbol):
else: # pragma: no cover
# It should be impossible to get here
raise NotImplementedError
# If symbol is a concatenation of Broadcasts, its average value is the
# thickness-weighted average of the symbols being broadcasted
elif isinstance(symbol, pybamm.Concatenation) and all(
isinstance(child, pybamm.Broadcast) for child in symbol.children
# If symbol is a concatenation, its average value is the
# thickness-weighted average of the average of its children
elif isinstance(symbol, pybamm.Concatenation) and not isinstance(
symbol, pybamm.ConcatenationVariable
):
geo = pybamm.geometric_parameters
l_n = geo.n.l
l_s = geo.s.l
l_p = geo.p.l
if symbol.domain == ["negative electrode", "separator", "positive electrode"]:
a, b, c = [orp.orphans[0] for orp in symbol.orphans]
out = (l_n * a + l_s * b + l_p * c) / (l_n + l_s + l_p)
elif symbol.domain == ["separator", "positive electrode"]:
b, c = [orp.orphans[0] for orp in symbol.orphans]
out = (l_s * b + l_p * c) / (l_s + l_p)
# To respect domains we may need to broadcast the child back out
child = symbol.children[0]
# If symbol being returned doesn't have empty domain, return it
if out.domain != []:
return out
# Otherwise we may need to broadcast it
elif child.domains["secondary"] == []:
return out
else:
domain = child.domains["secondary"]
if child.domains["tertiary"] == []:
return pybamm.PrimaryBroadcast(out, domain)
else:
auxiliary_domains = {"secondary": child.domains["tertiary"]}
return pybamm.FullBroadcast(out, domain, auxiliary_domains)
ls = {
("negative electrode",): geo.n.l,
("separator",): geo.s.l,
("positive electrode",): geo.p.l,
("separator", "positive electrode"): geo.s.l + geo.p.l,
}
out = sum(
ls[tuple(orp.domain)] * x_average(orp) for orp in symbol.orphans
) / sum(ls[tuple(orp.domain)] for orp in symbol.orphans)
return out
# Average of a sum is sum of averages
elif isinstance(symbol, (pybamm.Addition, pybamm.Subtraction)):
return _sum_of_averages(symbol, x_average)
Expand Down
12 changes: 2 additions & 10 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,16 +711,8 @@ def _simplified_binary_broadcast_concatenation(left, right, operator):
return left._concatenation_new_copy(
[operator(child, right) for child in left.orphans]
)
elif (
isinstance(right, pybamm.Concatenation)
and not any(
isinstance(child, (pybamm.Variable, pybamm.StateVector))
for child in right.children
)
and (
all(child.is_constant() for child in left.children)
or all(child.is_constant() for child in right.children)
)
elif isinstance(right, pybamm.Concatenation) and not isinstance(
right, pybamm.ConcatenationVariable
):
return left._concatenation_new_copy(
[
Expand Down
9 changes: 8 additions & 1 deletion pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def check_and_set_domains(self, child, broadcast_domain):
# Can only do primary broadcast from current collector to electrode,
# particle-size or particle or from electrode to particle-size or particle.
# Note e.g. current collector to particle *is* allowed
if broadcast_domain == []:
raise pybamm.DomainError("Cannot Broadcast an object into empty domain.")
if child.domain == []:
pass
elif child.domain == ["current collector"] and not (
Expand Down Expand Up @@ -430,7 +432,10 @@ def __init__(

def check_and_set_domains(self, child, broadcast_domains):
"""See :meth:`Broadcast.check_and_set_domains`"""

if broadcast_domains["primary"] == []:
raise pybamm.DomainError(
"""Cannot do full broadcast to an empty primary domain"""
)
# Variables on the current collector can only be broadcast to 'primary'
if child.domain == ["current collector"]:
raise pybamm.DomainError(
Expand Down Expand Up @@ -544,6 +549,8 @@ def full_like(symbols, fill_value):
return array_type(entries, domains=sum_symbol.domains)

except NotImplementedError:
if sum_symbol.shape_for_testing == (1, 1):
return pybamm.Scalar(fill_value)
if sum_symbol.evaluates_on_edges("primary"):
return FullBroadcastToEdges(
fill_value, broadcast_domains=sum_symbol.domains
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _get_auxiliary_domain_repeats(self, auxiliary_domains):
mesh_pts = 1
for level, dom in auxiliary_domains.items():
if level != "primary" and dom != []:
mesh_pts *= self.full_mesh.combine_submeshes(*dom).npts
mesh_pts *= self.full_mesh[dom].npts
return mesh_pts

@property
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def __abs__(self):
elif isinstance(self, pybamm.Broadcast):
# Move absolute value inside the broadcast
# Apply recursively
abs_self_not_broad = pybamm.simplify_if_constant(abs(self.orphans[0]))
abs_self_not_broad = abs(self.orphans[0])
return self._unary_new_copy(abs_self_not_broad)
else:
k = pybamm.settings.abs_smoothing
Expand Down
21 changes: 19 additions & 2 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,9 @@ def __init__(self, children, initial_condition):
def _unary_new_copy(self, child):
return self.__class__(child, self.initial_condition)

def is_constant(self):
return False


class BoundaryGradient(BoundaryOperator):
"""
Expand Down Expand Up @@ -1084,7 +1087,10 @@ def grad(symbol):
"""
# Gradient of a broadcast is zero
if isinstance(symbol, pybamm.PrimaryBroadcast):
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
if symbol.child.domain == []:
new_child = pybamm.Scalar(0)
else:
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
return pybamm.PrimaryBroadcastToEdges(new_child, symbol.domain)
elif isinstance(symbol, pybamm.FullBroadcast):
return pybamm.FullBroadcastToEdges(0, broadcast_domains=symbol.domains)
Expand All @@ -1110,7 +1116,10 @@ def div(symbol):
"""
# Divergence of a broadcast is zero
if isinstance(symbol, pybamm.PrimaryBroadcastToEdges):
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
if symbol.child.domain == []:
new_child = pybamm.Scalar(0)
else:
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
return pybamm.PrimaryBroadcast(new_child, symbol.domain)
# Divergence commutes with Negate operator
if isinstance(symbol, pybamm.Negate):
Expand Down Expand Up @@ -1245,6 +1254,14 @@ def boundary_value(symbol, side):

def sign(symbol):
"""Returns a :class:`Sign` object."""
if isinstance(symbol, pybamm.Broadcast):
# Move sign inside the broadcast
# Apply recursively
return symbol._unary_new_copy(sign(symbol.orphans[0]))
elif isinstance(symbol, pybamm.Concatenation) and not isinstance(
symbol, pybamm.ConcatenationVariable
):
return pybamm.concatenation(*[sign(child) for child in symbol.orphans])
return pybamm.simplify_if_constant(Sign(symbol))


Expand Down
32 changes: 24 additions & 8 deletions pybamm/meshes/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,30 @@ def __init__(self, geometry, submesh_types, var_pts):
geometry[domain][spatial_variable][lim] = sym_eval

# Create submeshes
self.base_domains = []
for domain in geometry:
self[domain] = submesh_types[domain](geometry[domain], submesh_pts[domain])
self.base_domains.append(domain)

# add ghost meshes
self.add_ghost_meshes()

def __getitem__(self, domains):
if isinstance(domains, str):
domains = (domains,)
domains = tuple(domains)
try:
return super().__getitem__(domains)
except KeyError:
value = self.combine_submeshes(*domains)
self[domains] = value
return value

def __setitem__(self, domains, value):
if isinstance(domains, str):
domains = (domains,)
super().__setitem__(domains, value)

def combine_submeshes(self, *submeshnames):
"""Combine submeshes into a new submesh, using self.submeshclass
Raises pybamm.DomainError if submeshes to be combined do not match up (edges are
Expand All @@ -134,9 +152,6 @@ def combine_submeshes(self, *submeshnames):
"""
if submeshnames == ():
raise ValueError("Submesh domains being combined cannot be empty")
# If there is just a single submesh, we can return it directly
if len(submeshnames) == 1:
return self[submeshnames[0]]
# Check that the final edge of each submesh is the same as the first edge of the
# next submesh
for i in range(len(submeshnames) - 1):
Expand All @@ -159,7 +174,6 @@ def combine_submeshes(self, *submeshnames):
submesh.internal_boundaries = [
self[submeshname].edges[0] for submeshname in submeshnames[1:]
]

return submesh

def add_ghost_meshes(self):
Expand All @@ -172,22 +186,24 @@ def add_ghost_meshes(self):
submeshes = [
(domain, submesh)
for domain, submesh in self.items()
if not isinstance(submesh, (pybamm.SubMesh0D, pybamm.ScikitSubMesh2D))
if (
len(domain) == 1
and not isinstance(submesh, (pybamm.SubMesh0D, pybamm.ScikitSubMesh2D))
)
]
for domain, submesh in submeshes:

edges = submesh.edges

# left ghost cell: two edges, one node, to the left of existing submesh
lgs_edges = np.array([2 * edges[0] - edges[1], edges[0]])
self[domain + "_left ghost cell"] = pybamm.SubMesh1D(
self[domain[0] + "_left ghost cell"] = pybamm.SubMesh1D(
lgs_edges, submesh.coord_sys
)

# right ghost cell: two edges, one node, to the right of
# existing submesh
rgs_edges = np.array([edges[-1], 2 * edges[-1] - edges[-2]])
self[domain + "_right ghost cell"] = pybamm.SubMesh1D(
self[domain[0] + "_right ghost cell"] = pybamm.SubMesh1D(
rgs_edges, submesh.coord_sys
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@ def _get_standard_concentration_variables(self, c_e_dict):
electrolyte.
"""

c_e_typ = self.param.c_e_typ
c_e = pybamm.concatenation(*c_e_dict.values())
# Override print_name
c_e.print_name = "c_e"

variables = {
"Electrolyte concentration": c_e,
"X-averaged electrolyte concentration": pybamm.x_average(c_e),
}
variables = self._get_standard_domain_concentration_variables(c_e_dict)
variables.update(self._get_standard_whole_cell_concentration_variables(c_e))
return variables

def _get_standard_domain_concentration_variables(self, c_e_dict):
c_e_typ = self.param.c_e_typ
variables = {}
# Case where an electrode is not included (half-cell)
if "negative electrode" not in self.options.whole_cell_domains:
c_e_s = c_e_dict["separator"]
Expand Down Expand Up @@ -75,6 +76,24 @@ def _get_standard_concentration_variables(self, c_e_dict):

return variables

def _get_standard_whole_cell_concentration_variables(self, c_e):
c_e_typ = self.param.c_e_typ

variables = {
"Electrolyte concentration": c_e,
"X-averaged electrolyte concentration": pybamm.x_average(c_e),
}
variables_nondim = variables.copy()
for name, var in variables_nondim.items():
variables.update(
{
f"{name} [mol.m-3]": c_e_typ * var,
f"{name} [Molar]": c_e_typ * var / 1000,
}
)

return variables

def _get_standard_porosity_times_concentration_variables(self, eps_c_e_dict):
eps_c_e = pybamm.concatenation(*eps_c_e_dict.values())
variables = {"Porosity times concentration": eps_c_e}
Expand Down
Loading