From a9c5dc9c483f0a94b082dd235366fcfe0c2a8620 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 11:17:32 +0100 Subject: [PATCH 1/8] Harden MapFusion conditions --- dace/transformation/dataflow/map_fusion.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 1969538ce5..8356c0da42 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -2,8 +2,9 @@ """ from copy import deepcopy as dcpy -from dace import dtypes, registry, symbolic +from dace import dtypes, registry, symbolic, subsets from dace.graph import nodes, nxutil +from dace.memlet import Memlet from dace.sdfg import replace from dace.transformation import pattern_matching from typing import List, Union @@ -122,6 +123,15 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False): if perm is None: return False + # Check if any intermediate transient is also going to another location + second_inodes = set(e.src for e in graph.in_edges(second_map_entry) + if isinstance(e.src, nodes.AccessNode)) + transients_to_remove = intermediate_nodes & second_inodes + # if any(e.dst != second_map_entry for n in transients_to_remove + # for e in graph.out_edges(n)): + if any(graph.out_degree(n) > 1 for n in transients_to_remove): + return False + # Create a dict that maps parameters of the first map to those of the # second map. params_dict = {} From 6d12dcda61379142c91cf9fe25b1e0e080eb5a82 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 11:17:53 +0100 Subject: [PATCH 2/8] Types for Scopes --- dace/sdfg.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dace/sdfg.py b/dace/sdfg.py index 96f815f555..43ca1c4e1e 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -59,12 +59,12 @@ def getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: class Scope(object): """ A class defining a scope, its parent and children scopes, variables, and scope entry/exit nodes. """ - def __init__(self, entrynode, exitnode): - self.parent = None - self.children = [] - self.defined_vars = [] - self.entry = entrynode - self.exit = exitnode + def __init__(self, entrynode: nd.EntryNode, exitnode: nd.ExitNode): + self.parent: 'Scope' = None + self.children: List['Scope'] = [] + self.defined_vars: List[str] = [] + self.entry: nd.EntryNode = entrynode + self.exit: nd.ExitNode = exitnode class InvalidSDFGError(Exception): From b0a899a935bebee7ea45c87794fc1c1da060165e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 12:04:17 +0100 Subject: [PATCH 3/8] Function to consolidate edges with the same memlets but different connectors --- dace/sdfg.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/dace/sdfg.py b/dace/sdfg.py index 43ca1c4e1e..228f7f864a 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -4580,3 +4580,89 @@ class skips the process. warnings.warn('Optimizer interface class "%s" not found' % clazz) return result + + +def consolidate_edges_scope(state: SDFGState, + scope_node: Union[nd.EntryNode, nd.ExitNode] + ) -> int: + """ + Union scope-entering memlets relating to the same data node in a scope. + This effectively reduces the number of connectors and allows more + transformations to be performed, at the cost of losing the individual + per-tasklet memlets. + :param state: The SDFG state in which the scope to consolidate resides. + :param scope_node: The scope node whose edges will be consolidated. + :return: Number of edges removed. + """ + if scope_node is None: + return 0 + data_to_conn = {} + consolidated = 0 + if isinstance(scope_node, nd.EntryNode): + outer_edges = state.in_edges + inner_edges = state.out_edges + remove_outer_connector = scope_node.remove_in_connector + remove_inner_connector = scope_node.remove_out_connector + prefix, oprefix = 'IN_', 'OUT_' + else: + outer_edges = state.out_edges + inner_edges = state.in_edges + remove_outer_connector = scope_node.remove_out_connector + remove_inner_connector = scope_node.remove_in_connector + prefix, oprefix = 'OUT_', 'IN_' + + edges_by_connector = collections.defaultdict(list) + connectors_to_remove = set() + for e in inner_edges(scope_node): + edges_by_connector[e.src_conn].append(e) + if e.data.data not in data_to_conn: + data_to_conn[e.data.data] = e.src_conn + elif data_to_conn[e.data.data] != e.src_conn: # Need to consolidate + connectors_to_remove.add(e.src_conn) + + for conn in connectors_to_remove: + e = edges_by_connector[conn][0] + # Outer side of the scope - remove edge and union subsets + target_conn = prefix + data_to_conn[e.data.data][len(oprefix):] + conn_to_remove = prefix + conn[len(oprefix):] + remove_outer_connector(conn_to_remove) + out_edge = next(ed for ed in outer_edges(scope_node) + if ed.dst_conn == target_conn) + edge_to_remove = next(ed for ed in outer_edges(scope_node) + if ed.dst_conn == conn_to_remove) + out_edge.data.subset = sbs.union(out_edge.data.subset, + edge_to_remove.data.subset) + state.remove_edge(edge_to_remove) + consolidated += 1 + # Inner side of the scope - remove and reconnect + remove_inner_connector(e.src_conn) + for e in edges_by_connector[conn]: + e._src_conn = data_to_conn[e.data.data] + + return consolidated + + +def consolidate_edges(sdfg: SDFG) -> int: + """ + Union scope-entering memlets relating to the same data node in all states. + This effectively reduces the number of connectors and allows more + transformations to be performed, at the cost of losing the individual + per-tasklet memlets. + :param sdfg: The SDFG to consolidate. + :return: Number of edges removed. + """ + consolidated = 0 + for state in sdfg.nodes(): + # Start bottom-up + queue = state.scope_leaves() + next_queue = [] + while len(queue) > 0: + for scope in queue: + consolidated += consolidate_edges_scope(state, scope.entry) + consolidated += consolidate_edges_scope(state, scope.exit) + if scope.parent is not None: + next_queue.append(scope.parent) + queue = next_queue + next_queue = [] + + return consolidated From 0ea3dccb5ec95b90de1db995ec48333d623a24a7 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 12:08:07 +0100 Subject: [PATCH 4/8] Add extended sympy simplification with respect to Min/Max functions --- dace/subsets.py | 30 +++++++++++++++++++++++------- dace/symbolic.py | 21 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 340fe4aa20..a33f639ce8 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -12,12 +12,21 @@ class Subset(object): def covers(self, other): """ Returns True if this subset covers (using a bounding box) another subset. """ + def nng(expr): + # When dealing with set sizes, assume symbols are non-negative + # TODO: Fix in symbol definition, not here + for sym in list(expr.free_symbols): + expr = expr.subs({sym: sp.Symbol(sym.name, nonnegative=True)}) + return expr + try: - return all([ - rb <= orb and re >= ore for rb, re, orb, ore in zip( - self.min_element(), self.max_element(), - other.min_element(), other.max_element()) - ]) + return all([(symbolic.simplify_ext(nng(rb)) <= + symbolic.simplify_ext(nng(orb))) == True + and (symbolic.simplify_ext(nng(re)) >= + symbolic.simplify_ext(nng(ore))) == True + for rb, re, orb, ore in zip( + self.min_element(), self.max_element_approx(), + other.min_element(), other.max_element_approx())]) except TypeError: return False @@ -68,6 +77,12 @@ def _expr(val): return val +def _approx(val): + if isinstance(val, symbolic.SymExpr): + return val.approx + return val + + def _tuple_to_symexpr(val): return (symbolic.SymExpr(val[0], val[1]) if isinstance(val, tuple) else symbolic.pystr_to_symbolic(val)) @@ -202,8 +217,9 @@ def min_element(self): def max_element(self): return [_expr(x[1]) for x in self.ranges] - # return [(sp.floor((iMax - iMin) / step) - 1) * step - # for iMin, iMax, step in self.ranges] + + def max_element_approx(self): + return [_approx(x[1]) for x in self.ranges] def coord_at(self, i): """ Returns the offseted coordinates of this subset at diff --git a/dace/symbolic.py b/dace/symbolic.py index f0af5b6430..bede952094 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -558,6 +558,27 @@ def sympy_divide_fix(expr): return nexpr +def simplify_ext(expr): + """ + An extended version of simplification with expression fixes for sympy. + :param expr: A sympy expression. + :return: Simplified version of the expression. + """ + a = sympy.Wild('a') + b = sympy.Wild('b') + c = sympy.Wild('c') + + # Push expressions into both sides of min/max. + # Example: Min(N, 4) + 1 => Min(N + 1, 5) + dic = expr.match(sympy.Min(a, b) + c) + if dic: + return sympy.Min(dic[a] + dic[c], dic[b] + dic[c]) + dic = expr.match(sympy.Max(a, b) + c) + if dic: + return sympy.Max(dic[a] + dic[c], dic[b] + dic[c]) + return expr + + def pystr_to_symbolic(expr, symbol_map=None, simplify=None): """ Takes a Python string and converts it into a symbolic expression. """ if isinstance(expr, SymExpr): From 750db752b4a5be8c4179c1121ca3fb14d1f52862 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 12:08:43 +0100 Subject: [PATCH 5/8] Support multiple destination edges in second map --- dace/transformation/dataflow/map_fusion.py | 38 +++++++++++++++------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 8356c0da42..9f7ca8d2f4 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -282,22 +282,20 @@ def apply(self, sdfg): # In this transformation, there can only be one edge to the # second map assert len(out_edges) == 1 + # Get source connector to the second map connector = out_edges[0].dst_conn[3:] - new_dst = None - new_dst_conn = None + new_dsts = [] # Look at the second map entry out-edges to get the new - # destination - for _e in graph.out_edges(second_entry): - if _e.src_conn[4:] == connector: - new_dst = _e.dst - new_dst_conn = _e.dst_conn - break - if new_dst is None: - # Access node is not used in the second map + # destinations + for e in graph.out_edges(second_entry): + if e.src_conn[4:] == connector: + new_dsts.append(e) + if not new_dsts: # Access node is not used in the second map nodes_to_remove.add(access_node) continue + # If the source is an access node, modify the memlet to point # to it if (isinstance(edge.src, nodes.AccessNode) @@ -309,7 +307,8 @@ def apply(self, sdfg): else: # Add a transient scalar/array - self.fuse_nodes(sdfg, graph, edge, new_dst, new_dst_conn) + self.fuse_nodes(sdfg, graph, edge, new_dsts[0].dst, + new_dsts[0].dst_conn, new_dsts[1:]) edges_to_remove.add(edge) @@ -388,8 +387,15 @@ def apply(self, sdfg): # Fix scope exit to point to the right map second_exit.map = first_entry.map - def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn): + def fuse_nodes(self, + sdfg, + graph, + edge, + new_dst, + new_dst_conn, + other_edges=None): """ Fuses two nodes via memlets and possibly transient arrays. """ + other_edges = other_edges or [] memlet_path = graph.memlet_path(edge) access_node = memlet_path[-1].dst @@ -417,6 +423,10 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn): # Add edge that leads to the second node graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) + + for e in other_edges: + graph.add_edge(local_node, src_connector, e.dst, e.dst_conn, + dcpy(edge.data)) else: sdfg.add_transient(local_name, edge.data.subset.size(), @@ -440,6 +450,10 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn): graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) + for e in other_edges: + graph.add_edge(local_node, src_connector, e.dst, e.dst_conn, + dcpy(edge.data)) + # Modify data and memlets on all surrounding edges to match array for neighbor in graph.all_edges(local_node): for e in graph.memlet_tree(neighbor): From 690537e8bef4af48b87a78d0e11c201ddd4a0c64 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 12:28:57 +0100 Subject: [PATCH 6/8] add missing method --- dace/subsets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dace/subsets.py b/dace/subsets.py index a33f639ce8..f2d3452aa2 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -627,6 +627,9 @@ def min_element(self): def max_element(self): return self.indices + def max_element_approx(self): + return [_approx(ind) for ind in self.indices] + def data_dims(self): return 0 From f7582fe9f16556e7c8d406c44ffff1a8373ca398 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 12:36:33 +0100 Subject: [PATCH 7/8] new formulation for MapFusion using subset cover --- dace/subsets.py | 11 ++++++++ dace/transformation/dataflow/map_fusion.py | 30 ++++++++-------------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index f2d3452aa2..969fd47767 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -569,6 +569,13 @@ def pop(self, dimensions): def string_list(self): return Range.ndslice_to_string_list(self.ranges, self.tile_sizes) + def replace(self, repl_dict): + for i, ((rb, re, rs), + ts) in enumerate(zip(self.ranges, self.tile_sizes)): + self.ranges[i] = (rb.subs(repl_dict), re.subs(repl_dict), + rs.subs(repl_dict)) + self.tile_sizes[i] = ts.subs(repl_dict) + @dace.serialize.serializable class Indices(Subset): @@ -745,6 +752,10 @@ def unsqueeze(self, axes): for axis in sorted(axes): self.indices.insert(axis, 0) + def replace(self, repl_dict): + for i, ind in enumerate(self.indices): + self.indices[i] = ind.subs(repl_dict) + def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: """ Perform union by creating a bounding-box of two subsets. """ diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 9f7ca8d2f4..37dd142c84 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -157,28 +157,20 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False): continue provided = False + + # Compute second subset with respect to first subset's symbols + sbs_permuted = dcpy(second_memlet.subset) + sbs_permuted.replace({ + symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) + for k, v in params_dict.items() + }) + for first_memlet in out_memlets: if first_memlet.data != second_memlet.data: continue - # If there is an equivalent subset, it is provided - expected_second_subset = [] - for _tup in first_memlet.subset: - new_tuple = [] - if isinstance(_tup, symbolic.symbol): - new_tuple = symbolic.symbol(params_dict[str(_tup)]) - elif isinstance(_tup, (list, tuple)): - for _sym in _tup: - if (isinstance(_sym, symbolic.symbol) - and str(_sym) in params_dict): - new_tuple.append( - symbolic.symbol(params_dict[str(_sym)])) - else: - new_tuple.append(_sym) - new_tuple = tuple(new_tuple) - else: - new_tuple = _tup - expected_second_subset.append(new_tuple) - if expected_second_subset == list(second_memlet.subset): + + # If there is a covered subset, it is provided + if first_memlet.subset.covers(sbs_permuted): provided = True break From 7998e285d372818e9a0a343a28c89a56d2e1507c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 10 Mar 2020 23:35:51 +0100 Subject: [PATCH 8/8] Fix minor issue with subset.replace --- dace/subsets.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 969fd47767..46a6ab16b5 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -572,9 +572,12 @@ def string_list(self): def replace(self, repl_dict): for i, ((rb, re, rs), ts) in enumerate(zip(self.ranges, self.tile_sizes)): - self.ranges[i] = (rb.subs(repl_dict), re.subs(repl_dict), - rs.subs(repl_dict)) - self.tile_sizes[i] = ts.subs(repl_dict) + self.ranges[i] = ( + rb.subs(repl_dict) if symbolic.issymbolic(rb) else rb, + re.subs(repl_dict) if symbolic.issymbolic(re) else re, + rs.subs(repl_dict) if symbolic.issymbolic(rs) else rs) + self.tile_sizes[i] = (ts.subs(repl_dict) + if symbolic.issymbolic(ts) else ts) @dace.serialize.serializable @@ -754,7 +757,8 @@ def unsqueeze(self, axes): def replace(self, repl_dict): for i, ind in enumerate(self.indices): - self.indices[i] = ind.subs(repl_dict) + self.indices[i] = (ind.subs(repl_dict) + if symbolic.issymbolic(ind) else ind) def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: