Patch summary (commentary only; patch tools skip text before the first
`diff --git` header):

  * devito/mpi/distributed.py -- only commented-out topology overrides inside
    `__init__` are touched; no executable line changes.
  * devito/passes/clusters/cse.py -- the local `Temp(Temp0)` subclass is
    replaced by `CTemp(Temp)`; 'CTemp' is inserted into SymPy's
    `ordering_of_classes` immediately after 'Temp'; `cse` builds `CTemp`
    objects and `_compact_temporaries` tests for `CTemp`.

NOTE(review): `ordering_of_classes` is imported from `sympy.core.core`, an
internal SymPy module -- presumably its location depends on the SymPy
version; confirm against the supported SymPy range.

diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py
index ae109a7389..d745096b67 100644
--- a/devito/mpi/distributed.py
+++ b/devito/mpi/distributed.py
@@ -202,10 +202,7 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
         self._input_comm = (input_comm or MPI.COMM_WORLD).Clone()
 
         # if len(shape) == 3:
-        #     topology = ('*', '*', 1)
-
-        # topology = ('*', '*', 1)
-        # topology = ('*', '*', 8)
+        #     topology = ('*', '*', 4)
 
         if topology is None:
             # `MPI.Compute_dims` sets the dimension sizes to be as close to each other
diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py
index 5f00a0341e..ddff97f7c6 100644
--- a/devito/passes/clusters/cse.py
+++ b/devito/passes/clusters/cse.py
@@ -2,6 +2,7 @@
 from functools import singledispatch
 
 from sympy import Add, Function, Indexed, Mul, Pow
+from sympy.core.core import ordering_of_classes
 
 from devito.finite_differences.differentiable import IndexDerivative
 from devito.ir import Cluster, Scope, cluster_pass
@@ -9,13 +10,17 @@
 from devito.symbolics import estimate_cost, q_leaf
 from devito.symbolics.manipulation import _uxreplace
 from devito.tools import as_list
-from devito.types import Eq, Temp as Temp0
+from devito.types import Eq, Temp
 
 __all__ = ['cse']
 
 
-class Temp(Temp0):
-    pass
+class CTemp(Temp):
+
+    """
+    A cluster-level Temp, similar to Temp, ensured to have different priority
+    """
+    ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')
 
 
 @cluster_pass
@@ -23,7 +28,7 @@ def cse(cluster, sregistry, options, *args):
     """
     Common sub-expressions elimination (CSE).
     """
-    make = lambda: Temp(name=sregistry.make_name(), dtype=cluster.dtype)
+    make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)
     exprs = _cse(cluster, make, min_cost=options['cse-min-cost'])
 
     return cluster.rebuild(exprs=exprs)
@@ -130,7 +135,7 @@ def _compact_temporaries(exprs, exclude):
     # safely be compacted; a generic Symbol could instead be accessed in a subsequent
     # Cluster, for example: `for (i = ...) { a = b; for (j = a ...) ...`
     mapper = {e.lhs: e.rhs for e in exprs
-              if isinstance(e.lhs, Temp) and q_leaf(e.rhs) and e.lhs not in exclude}
+              if isinstance(e.lhs, CTemp) and q_leaf(e.rhs) and e.lhs not in exclude}
 
     processed = []
     for e in exprs: