Skip to content

Commit

Permalink
Merge pull request spcl#167 from spcl/simplify-ext
Browse files Browse the repository at this point in the history
MapFusion fixes and better simplification support
  • Loading branch information
tbennun authored Mar 11, 2020
2 parents a927bd6 + 9d3b855 commit 448553d
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 45 deletions.
98 changes: 92 additions & 6 deletions dace/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo:
class Scope(object):
    """ A class defining a scope, its parent and children scopes, variables,
        and scope entry/exit nodes. """

    def __init__(self, entrynode: nd.EntryNode, exitnode: nd.ExitNode):
        """
        Creates a new scope delimited by the given entry and exit nodes.
        :param entrynode: The node through which the scope is entered.
        :param exitnode: The node through which the scope is exited.
        """
        # Enclosing scope, or None for a top-level scope
        self.parent: 'Scope' = None
        # Scopes directly nested within this one
        self.children: List['Scope'] = []
        # Names of variables defined in this scope
        self.defined_vars: List[str] = []
        self.entry: nd.EntryNode = entrynode
        self.exit: nd.ExitNode = exitnode


class InvalidSDFGError(Exception):
Expand Down Expand Up @@ -4580,3 +4580,89 @@ class skips the process.
warnings.warn('Optimizer interface class "%s" not found' % clazz)

return result


def consolidate_edges_scope(state: SDFGState,
                            scope_node: Union[nd.EntryNode, nd.ExitNode]
                            ) -> int:
    """
    Union scope-entering memlets relating to the same data node in a scope.
    This effectively reduces the number of connectors and allows more
    transformations to be performed, at the cost of losing the individual
    per-tasklet memlets.
    :param state: The SDFG state in which the scope to consolidate resides.
    :param scope_node: The scope node whose edges will be consolidated.
    :return: Number of edges removed.
    """
    if scope_node is None:
        return 0
    data_to_conn = {}
    consolidated = 0
    if isinstance(scope_node, nd.EntryNode):
        outer_edges = state.in_edges
        inner_edges = state.out_edges
        # Inner edges of an entry node leave it, so the scope-side connector
        # is the edge's source connector; outer edges enter it.
        inner_conn = lambda e: e.src_conn
        outer_conn = lambda e: e.dst_conn
        remove_outer_connector = scope_node.remove_in_connector
        remove_inner_connector = scope_node.remove_out_connector
        prefix, oprefix = 'IN_', 'OUT_'
    else:
        outer_edges = state.out_edges
        inner_edges = state.in_edges
        # FIX: inner edges of an exit node *enter* it, so the scope-side
        # connector is the edge's destination connector (the original code
        # used ``src_conn`` for both cases, which is only valid for entries).
        inner_conn = lambda e: e.dst_conn
        outer_conn = lambda e: e.src_conn
        remove_outer_connector = scope_node.remove_out_connector
        remove_inner_connector = scope_node.remove_in_connector
        prefix, oprefix = 'OUT_', 'IN_'

    # Group inner edges by their scope-side connector, and find connectors
    # that refer to data already reachable through another connector.
    edges_by_connector = collections.defaultdict(list)
    connectors_to_remove = set()
    for e in inner_edges(scope_node):
        conn = inner_conn(e)
        edges_by_connector[conn].append(e)
        if e.data.data not in data_to_conn:
            data_to_conn[e.data.data] = conn
        elif data_to_conn[e.data.data] != conn:  # Need to consolidate
            connectors_to_remove.add(conn)

    for conn in connectors_to_remove:
        e = edges_by_connector[conn][0]
        # Outer side of the scope - remove edge and union subsets
        target_conn = prefix + data_to_conn[e.data.data][len(oprefix):]
        conn_to_remove = prefix + conn[len(oprefix):]
        remove_outer_connector(conn_to_remove)
        out_edge = next(ed for ed in outer_edges(scope_node)
                        if outer_conn(ed) == target_conn)
        edge_to_remove = next(ed for ed in outer_edges(scope_node)
                              if outer_conn(ed) == conn_to_remove)
        out_edge.data.subset = sbs.union(out_edge.data.subset,
                                         edge_to_remove.data.subset)
        state.remove_edge(edge_to_remove)
        consolidated += 1
        # Inner side of the scope - remove connector and reconnect edges to
        # the surviving connector for the same data.
        remove_inner_connector(conn)
        for e in edges_by_connector[conn]:
            if isinstance(scope_node, nd.EntryNode):
                e._src_conn = data_to_conn[e.data.data]
            else:
                e._dst_conn = data_to_conn[e.data.data]

    return consolidated


def consolidate_edges(sdfg: SDFG) -> int:
    """
    Union scope-entering memlets relating to the same data node in all states.
    This effectively reduces the number of connectors and allows more
    transformations to be performed, at the cost of losing the individual
    per-tasklet memlets.
    :param sdfg: The SDFG to consolidate.
    :return: Number of edges removed.
    """
    consolidated = 0
    for state in sdfg.nodes():
        # Start bottom-up, from the innermost (leaf) scopes outward
        queue = state.scope_leaves()
        next_queue = []
        while len(queue) > 0:
            for scope in queue:
                consolidated += consolidate_edges_scope(state, scope.entry)
                consolidated += consolidate_edges_scope(state, scope.exit)
                # Avoid enqueueing a shared parent scope once per child,
                # which would redundantly re-process it in the next pass
                if scope.parent is not None and scope.parent not in next_queue:
                    next_queue.append(scope.parent)
            queue = next_queue
            next_queue = []

    return consolidated
48 changes: 41 additions & 7 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,21 @@ class Subset(object):
def covers(self, other):
    """ Returns True if this subset covers (using a bounding box) another
        subset.
        :param other: The subset that may be covered by this one.
        :return: True if every dimension of ``other`` lies within this
                 subset's bounding box, False otherwise (including when
                 the symbolic comparison cannot be decided).
    """
    def nng(expr):
        # When dealing with set sizes, assume symbols are non-negative
        # TODO: Fix in symbol definition, not here
        for sym in list(expr.free_symbols):
            expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)})
        return expr

    try:
        # Compare simplified symbolic bounds per dimension. The explicit
        # "== True" guards against symbolic relations that do not resolve
        # to a concrete boolean (they remain Relational objects).
        return all([(symbolic.simplify_ext(nng(rb)) <=
                     symbolic.simplify_ext(nng(orb))) == True
                    and (symbolic.simplify_ext(nng(re)) >=
                         symbolic.simplify_ext(nng(ore))) == True
                    for rb, re, orb, ore in zip(
                        self.min_element(), self.max_element_approx(),
                        other.min_element(), other.max_element_approx())])
    except TypeError:
        return False

Expand Down Expand Up @@ -68,6 +77,12 @@ def _expr(val):
return val


def _approx(val):
    """ Returns the approximated expression of a SymExpr, or the value
        itself for any other type. """
    if not isinstance(val, symbolic.SymExpr):
        return val
    return val.approx


def _tuple_to_symexpr(val):
    # A (main, approximation) pair becomes a SymExpr; any other value is
    # parsed into a symbolic expression directly.
    if isinstance(val, tuple):
        return symbolic.SymExpr(val[0], val[1])
    return symbolic.pystr_to_symbolic(val)
Expand Down Expand Up @@ -202,8 +217,9 @@ def min_element(self):

def max_element(self):
    """ Returns the exact upper-bound expression of each dimension. """
    return [_expr(rng[1]) for rng in self.ranges]

def max_element_approx(self):
    """ Returns the approximated upper-bound expression of each
        dimension (see ``_approx``). """
    return [_approx(rng[1]) for rng in self.ranges]

def coord_at(self, i):
""" Returns the offseted coordinates of this subset at
Expand Down Expand Up @@ -553,6 +569,16 @@ def pop(self, dimensions):
def string_list(self):
return Range.ndslice_to_string_list(self.ranges, self.tile_sizes)

def replace(self, repl_dict):
    """ Substitutes symbols in-place in every range bound, step, and tile
        size according to the given replacement dictionary.
        :param repl_dict: Mapping from symbols to replacement expressions.
    """
    def _sub(expr):
        # Only symbolic expressions support substitution; pass through
        # plain values (e.g., integers) unchanged.
        return expr.subs(repl_dict) if symbolic.issymbolic(expr) else expr

    for i, ((rb, re, rs),
            ts) in enumerate(zip(self.ranges, self.tile_sizes)):
        self.ranges[i] = (_sub(rb), _sub(re), _sub(rs))
        self.tile_sizes[i] = _sub(ts)


@dace.serialize.serializable
class Indices(Subset):
Expand Down Expand Up @@ -611,6 +637,9 @@ def min_element(self):
def max_element(self):
    """ Returns the maximum element per dimension — for an Indices
        subset this is the index expressions themselves. """
    return self.indices

def max_element_approx(self):
    """ Returns the approximated expression of each index
        (see ``_approx``). """
    return list(map(_approx, self.indices))

def data_dims(self):
    """ Returns 0 — an Indices subset contributes no data dimensions. """
    return 0

Expand Down Expand Up @@ -726,6 +755,11 @@ def unsqueeze(self, axes):
for axis in sorted(axes):
self.indices.insert(axis, 0)

def replace(self, repl_dict):
    """ Substitutes symbols in-place in every index expression according
        to the given replacement dictionary.
        :param repl_dict: Mapping from symbols to replacement expressions.
    """
    # Slice-assign so the existing list object is mutated in place
    self.indices[:] = [
        ind.subs(repl_dict) if symbolic.issymbolic(ind) else ind
        for ind in self.indices
    ]


def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
""" Perform union by creating a bounding-box of two subsets. """
Expand Down
21 changes: 21 additions & 0 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,27 @@ def sympy_divide_fix(expr):
return nexpr


def simplify_ext(expr):
    """
    An extended version of simplification with expression fixes for sympy.
    :param expr: A sympy expression.
    :return: Simplified version of the expression.
    """
    a = sympy.Wild('a')
    b = sympy.Wild('b')
    c = sympy.Wild('c')

    # Push expressions into both sides of min/max.
    # Example: Min(N, 4) + 1 => Min(N + 1, 5)
    # Min is tried before Max, matching the original evaluation order.
    for func in (sympy.Min, sympy.Max):
        match = expr.match(func(a, b) + c)
        if match:
            return func(match[a] + match[c], match[b] + match[c])
    return expr


def pystr_to_symbolic(expr, symbol_map=None, simplify=None):
""" Takes a Python string and converts it into a symbolic expression. """
if isinstance(expr, SymExpr):
Expand Down
80 changes: 48 additions & 32 deletions dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
"""

from copy import deepcopy as dcpy
from dace import dtypes, registry, symbolic
from dace import dtypes, registry, symbolic, subsets
from dace.graph import nodes, nxutil
from dace.memlet import Memlet
from dace.sdfg import replace
from dace.transformation import pattern_matching
from typing import List, Union
Expand Down Expand Up @@ -122,6 +123,15 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
if perm is None:
return False

# Check if any intermediate transient is also going to another location
second_inodes = set(e.src for e in graph.in_edges(second_map_entry)
if isinstance(e.src, nodes.AccessNode))
transients_to_remove = intermediate_nodes & second_inodes
# if any(e.dst != second_map_entry for n in transients_to_remove
# for e in graph.out_edges(n)):
if any(graph.out_degree(n) > 1 for n in transients_to_remove):
return False

# Create a dict that maps parameters of the first map to those of the
# second map.
params_dict = {}
Expand All @@ -147,28 +157,20 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
continue

provided = False

# Compute second subset with respect to first subset's symbols
sbs_permuted = dcpy(second_memlet.subset)
sbs_permuted.replace({
symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v)
for k, v in params_dict.items()
})

for first_memlet in out_memlets:
if first_memlet.data != second_memlet.data:
continue
# If there is an equivalent subset, it is provided
expected_second_subset = []
for _tup in first_memlet.subset:
new_tuple = []
if isinstance(_tup, symbolic.symbol):
new_tuple = symbolic.symbol(params_dict[str(_tup)])
elif isinstance(_tup, (list, tuple)):
for _sym in _tup:
if (isinstance(_sym, symbolic.symbol)
and str(_sym) in params_dict):
new_tuple.append(
symbolic.symbol(params_dict[str(_sym)]))
else:
new_tuple.append(_sym)
new_tuple = tuple(new_tuple)
else:
new_tuple = _tup
expected_second_subset.append(new_tuple)
if expected_second_subset == list(second_memlet.subset):

# If there is a covered subset, it is provided
if first_memlet.subset.covers(sbs_permuted):
provided = True
break

Expand Down Expand Up @@ -272,22 +274,20 @@ def apply(self, sdfg):
# In this transformation, there can only be one edge to the
# second map
assert len(out_edges) == 1

# Get source connector to the second map
connector = out_edges[0].dst_conn[3:]

new_dst = None
new_dst_conn = None
new_dsts = []
# Look at the second map entry out-edges to get the new
# destination
for _e in graph.out_edges(second_entry):
if _e.src_conn[4:] == connector:
new_dst = _e.dst
new_dst_conn = _e.dst_conn
break
if new_dst is None:
# Access node is not used in the second map
# destinations
for e in graph.out_edges(second_entry):
if e.src_conn[4:] == connector:
new_dsts.append(e)
if not new_dsts: # Access node is not used in the second map
nodes_to_remove.add(access_node)
continue

# If the source is an access node, modify the memlet to point
# to it
if (isinstance(edge.src, nodes.AccessNode)
Expand All @@ -299,7 +299,8 @@ def apply(self, sdfg):

else:
# Add a transient scalar/array
self.fuse_nodes(sdfg, graph, edge, new_dst, new_dst_conn)
self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst,
new_dsts[0].dst_conn, new_dsts[1:])

edges_to_remove.add(edge)

Expand Down Expand Up @@ -378,8 +379,15 @@ def apply(self, sdfg):
# Fix scope exit to point to the right map
second_exit.map = first_entry.map

def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn):
def fuse_nodes(self,
sdfg,
graph,
edge,
new_dst,
new_dst_conn,
other_edges=None):
""" Fuses two nodes via memlets and possibly transient arrays. """
other_edges = other_edges or []
memlet_path = graph.memlet_path(edge)
access_node = memlet_path[-1].dst

Expand Down Expand Up @@ -407,6 +415,10 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn):
# Add edge that leads to the second node
graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
dcpy(edge.data))

for e in other_edges:
graph.add_edge(local_node, src_connector, e.dst, e.dst_conn,
dcpy(edge.data))
else:
sdfg.add_transient(local_name,
edge.data.subset.size(),
Expand All @@ -430,6 +442,10 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn):
graph.add_edge(local_node, src_connector, new_dst, new_dst_conn,
dcpy(edge.data))

for e in other_edges:
graph.add_edge(local_node, src_connector, e.dst, e.dst_conn,
dcpy(edge.data))

# Modify data and memlets on all surrounding edges to match array
for neighbor in graph.all_edges(local_node):
for e in graph.memlet_tree(neighbor):
Expand Down

0 comments on commit 448553d

Please sign in to comment.