Skip to content

Commit

Permalink
compiler: Add cluster-level temp
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 16, 2023
1 parent 56a7ddc commit af96d7a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
5 changes: 1 addition & 4 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
self._input_comm = (input_comm or MPI.COMM_WORLD).Clone()

# if len(shape) == 3:
# topology = ('*', '*', 1)

# topology = ('*', '*', 1)
# topology = ('*', '*', 8)
# topology = ('*', '*', 4)

if topology is None:
# `MPI.Compute_dims` sets the dimension sizes to be as close to each other
Expand Down
15 changes: 10 additions & 5 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,33 @@
from functools import singledispatch

from sympy import Add, Function, Indexed, Mul, Pow
from sympy.core.core import ordering_of_classes

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.types import Eq, Temp as Temp0
from devito.types import Eq, Temp

__all__ = ['cse']


class Temp(Temp0):
pass
class CTemp(Temp):

"""
A cluster-level Temp, similar to Temp, ensured to have different priority
"""
ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')


@cluster_pass
def cse(cluster, sregistry, options, *args):
"""
Common sub-expressions elimination (CSE).
"""
make = lambda: Temp(name=sregistry.make_name(), dtype=cluster.dtype)
make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)
exprs = _cse(cluster, make, min_cost=options['cse-min-cost'])

return cluster.rebuild(exprs=exprs)
Expand Down Expand Up @@ -130,7 +135,7 @@ def _compact_temporaries(exprs, exclude):
# safely be compacted; a generic Symbol could instead be accessed in a subsequent
# Cluster, for example: `for (i = ...) { a = b; for (j = a ...) ...`
mapper = {e.lhs: e.rhs for e in exprs
if isinstance(e.lhs, Temp) and q_leaf(e.rhs) and e.lhs not in exclude}
if isinstance(e.lhs, CTemp) and q_leaf(e.rhs) and e.lhs not in exclude}

processed = []
for e in exprs:
Expand Down

0 comments on commit af96d7a

Please sign in to comment.