From 57edcbaf043e9408c53aa07018eabdc5c3207404 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 17 Oct 2024 16:48:01 +0200 Subject: [PATCH 01/16] Fix the issue with cpp codegen, where it currently cannot handle inputs like: ```c++ cpp.reshape_strides(Range([(0, 4, 1), (0, 5, 1)]), None, None, [2, 3, 5]) ``` and crashes with an index error. --- dace/codegen/targets/cpp.py | 1 + tests/codegen/targets/cpp_test.py | 75 +++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 tests/codegen/targets/cpp_test.py diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index c34c829c31..b42e310655 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -417,6 +417,7 @@ def reshape_strides(subset, strides, original_strides, copy_shape): dims = len(copy_shape) reduced_tile_sizes = [ts for ts, s in zip(subset.tile_sizes, original_copy_shape) if s != 1] + reduced_tile_sizes += [1] * (dims - len(reduced_tile_sizes)) # Pad the remainder with 1s to maintain dimensions. reshaped_copy = copy_shape + [ts for ts in subset.tile_sizes if ts != 1] reshaped_copy[:len(copy_shape)] = [s / ts for s, ts in zip(copy_shape, reduced_tile_sizes)] diff --git a/tests/codegen/targets/cpp_test.py b/tests/codegen/targets/cpp_test.py new file mode 100644 index 0000000000..4be5fcf4ff --- /dev/null +++ b/tests/codegen/targets/cpp_test.py @@ -0,0 +1,75 @@ +import unittest +from functools import reduce +from operator import mul + +from dace.codegen.targets import cpp +from dace.subsets import Range + + +class ReshapeStrides(unittest.TestCase): + def test_multidim_array_all_dims_unit(self): + r = Range([(0, 0, 1), (0, 0, 1)]) + + # To smaller-sized shape + target_dims = [1] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, [1]) + + # To equal-sized shape + target_dims = [1, 1] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, [1, 1]) + + # To larger-sized shape + target_dims = [1, 1, 1] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, [1, 1, 1]) + + def test_multidim_array_some_dims_unit(self): + r = Range([(0, 1, 1), (0, 0, 1)]) + + # To smaller-sized shape + target_dims = [2] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + + # To equal-sized shape + target_dims = [2, 1] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + + # To equal-sized shape + target_dims = [2, 1, 1] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + + def test_multidim_array_different_shape(self): + r = Range([(0, 4, 1), (0, 5, 1)]) + + # To smaller-sized shape + target_dims = [30] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + + # To equal-sized shape + target_dims = [15, 2] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + + # To equal-sized shape + target_dims = [3, 5, 2] + self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + + +if __name__ == '__main__': + unittest.main() From aff5a9e0415ab6ffccd74164471e41744d8f6293 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 17 Oct 2024 18:40:25 +0200 Subject: [PATCH 02/16] Use integer division, fractions don't make sense here anyway. --- dace/codegen/targets/cpp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index b42e310655..9d4c11cd6c 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -420,7 +420,7 @@ def reshape_strides(subset, strides, original_strides, copy_shape): reduced_tile_sizes += [1] * (dims - len(reduced_tile_sizes)) # Pad the remainder with 1s to maintain dimensions. reshaped_copy = copy_shape + [ts for ts in subset.tile_sizes if ts != 1] - reshaped_copy[:len(copy_shape)] = [s / ts for s, ts in zip(copy_shape, reduced_tile_sizes)] + reshaped_copy[:len(copy_shape)] = [s // ts for s, ts in zip(copy_shape, reduced_tile_sizes)] new_strides = [0] * len(reshaped_copy) elements_remaining = functools.reduce(sp.Mul, copy_shape, 1) From a40ba15e54e397f74ac82d28aba381fe699b6864 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 17 Oct 2024 18:42:41 +0200 Subject: [PATCH 03/16] Add even more tests to cover strides and offsets. --- tests/codegen/targets/cpp_test.py | 91 +++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 11 deletions(-) diff --git a/tests/codegen/targets/cpp_test.py b/tests/codegen/targets/cpp_test.py index 4be5fcf4ff..a22cc7bb6e 100644 --- a/tests/codegen/targets/cpp_test.py +++ b/tests/codegen/targets/cpp_test.py @@ -12,63 +12,132 @@ def test_multidim_array_all_dims_unit(self): # To smaller-sized shape target_dims = [1] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, [1]) + self.assertEqual(strides, [1]) # To equal-sized shape target_dims = [1, 1] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, [1, 1]) + self.assertEqual(strides, [1, 1]) # To larger-sized shape target_dims = [1, 1, 1] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, [1, 1, 1]) + self.assertEqual(strides, [1, 1, 1]) def test_multidim_array_some_dims_unit(self): r = Range([(0, 1, 1), (0, 0, 1)]) # To smaller-sized shape target_dims = [2] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [1]) # To equal-sized shape target_dims = [2, 1] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [1, 1]) + # To equal-sized shape, but units first. + target_dims = [1, 2] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [2, 1]) - # To equal-sized shape + # To larger-sized shape. target_dims = [2, 1, 1] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [1, 1, 1]) + # To larger-sized shape, but units first. + target_dims = [1, 1, 2] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [2, 2, 1]) def test_multidim_array_different_shape(self): r = Range([(0, 4, 1), (0, 5, 1)]) # To smaller-sized shape target_dims = [30] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [1]) # To equal-sized shape target_dims = [15, 2] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [2, 1]) - # To equal-sized shape + # To larger-sized shape target_dims = [3, 5, 2] - self.assertEqual(r.num_elements_exact(), reduce(mul, target_dims)) + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [10, 2, 1]) + + def test_from_strided_range(self): + r = Range([(0, 4, 2), (0, 6, 2)]) + + # To smaller-sized shape + target_dims = [12] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [1]) + + # To equal-sized shape + target_dims = [4, 3] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [3, 1]) + + # To larger-sized shape + target_dims = [2, 3, 2] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [6, 2, 1]) + + def test_from_strided_and_offset_range(self): + r = Range([(10, 14, 2), (10, 16, 2)]) + + # To smaller-sized shape + target_dims = [12] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [1]) + + # To equal-sized shape + target_dims = [4, 3] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) + reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) + self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [3, 1]) + + # To larger-sized shape + target_dims = [2, 3, 2] + self.assertEqual(reduce(mul, r.size_exact()), reduce(mul, target_dims)) reshaped, strides = cpp.reshape_strides(r, None, None, target_dims) self.assertEqual(reshaped, target_dims) + self.assertEqual(strides, [6, 2, 1]) if __name__ == '__main__': From 906791c671997e6d1431acc1fcc9e03fc0f7f529 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 10:49:19 +0200 Subject: [PATCH 04/16] Start with cleaning up the imports. --- .../subgraph/subgraph_fusion.py | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 1ff286b85c..ee3babb16b 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -1,30 +1,24 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement subgraph fusion. """ -import dace +import warnings +from collections import defaultdict +from copy import deepcopy as dcpy +from itertools import chain +from typing import List, Tuple + import networkx as nx -from dace import dtypes, registry, symbolic, subsets, data -from dace.sdfg import nodes, utils, replace, SDFG, scope_contains_scope -from dace.sdfg.graph import SubgraphView -from dace.sdfg.scope import ScopeTree +import dace +from dace import dtypes, symbolic, subsets, data from dace.memlet import Memlet -from dace.transformation import transformation from dace.properties import EnumProperty, ListProperty, make_properties, Property -from dace.symbolic import overapproximate -from dace.sdfg.propagation import propagate_memlets_sdfg, propagate_memlet, propagate_memlets_scope, _propagate_node -from dace.transformation.subgraph import helpers -from dace.transformation.dataflow import RedundantArray -from dace.sdfg.utils import consolidate_edges_scope, get_view_node +from dace.sdfg import nodes, SDFG +from dace.sdfg.graph import SubgraphView +from dace.sdfg.propagation import _propagate_node +from dace.sdfg.utils import consolidate_edges_scope +from dace.transformation import transformation from dace.transformation.helpers import find_contiguous_subsets - -from copy import deepcopy as dcpy -from typing import List, Union, Tuple -import warnings - -import dace.libraries.standard as stdlib - -from collections import defaultdict -from itertools import chain +from dace.transformation.subgraph import helpers @make_properties From 58f135a9726b94d3f860ae5f5bb29e5755aa604e Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 10:50:06 +0200 Subject: [PATCH 05/16] Since `edge` is a loop variable, we probably don't want to use that outside the loop. It seems to be a typo on `iedge`. --- dace/transformation/subgraph/subgraph_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index ee3babb16b..1b8f9e2561 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -1094,7 +1094,7 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s # nested SDFG: adjust arrays connected if isinstance(iedge.src, nodes.NestedSDFG): nsdfg = iedge.src.sdfg - nested_data_name = edge.src_conn + nested_data_name = iedge.src_conn self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data, nested_data_name, iedge.data) for cedge in out_edges: From e80f9d79c51fbc877ff5808cfd8cdf2d8c7e3229 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 11:17:28 +0200 Subject: [PATCH 06/16] Fix a bunch of docstrings and typehints. --- .../subgraph/subgraph_fusion.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 1b8f9e2561..7813754026 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -4,7 +4,7 @@ from collections import defaultdict from copy import deepcopy as dcpy from itertools import chain -from typing import List, Tuple +from typing import List, Tuple, Set, Iterable import networkx as nx @@ -343,7 +343,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: @staticmethod def get_adjacent_nodes( - sdfg, graph, map_entries) -> Tuple[List[nodes.AccessNode], List[nodes.AccessNode], List[nodes.AccessNode]]: + sdfg, graph, map_entries) -> Tuple[Set[nodes.AccessNode], Set[nodes.AccessNode], Set[nodes.AccessNode]]: """ For given map entries, finds a set of in, out and intermediate nodes as defined below @@ -592,7 +592,7 @@ def adjust_arrays_nsdfg(self, sdfg: dace.sdfg.SDFG, nsdfg: nodes.NestedSDFG, nam @staticmethod def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - intermediate_nodes: List[nodes.AccessNode], + intermediate_nodes: Iterable[nodes.AccessNode], map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit], do_not_override: List[str] = []): @@ -603,12 +603,12 @@ def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, intermediate node as a key. :param sdfg: SDFG - :param state: State of interest + :param graph: State of interest :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph :param map_entries: List of outermost scoped map entries in the subgraph :param map_exits: List of map exits corresponding to map_entries in order :param do_not_override: List of data array names not to be compressed - :param return: A dictionary indicating for each data string whether its array can be compressed + :return: A dictionary indicating for each data string whether its array can be compressed """ # search whether intermediate_nodes appear outside of subgraph @@ -642,14 +642,14 @@ def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, return subgraph_contains_data def clone_intermediate_nodes(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - intermediate_nodes: List[nodes.AccessNode], out_nodes: List[nodes.AccessNode], + intermediate_nodes: Set[nodes.AccessNode], out_nodes: Set[nodes.AccessNode], map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit]): """ Creates cloned access nodes and data arrays for nodes that are both in intermediate nodes and out nodes, redirecting output from the original node to the cloned node. Operates in-place. :param sdfg: SDFG - :param state: State of interest + :param graph: State of interest :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph :param out_nodes: List of out nodes appearing in a fusible subgraph :param map_entries: List of outermost scoped map entries in the subgraph @@ -688,7 +688,7 @@ def clone_intermediate_nodes(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGSt return transients_created def determine_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - intermediate_nodes: List[nodes.AccessNode], map_entries: List[nodes.MapEntry], + intermediate_nodes: Iterable[nodes.AccessNode], map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit]): """ Determines the invariant dimensions for each node -- dimensions in @@ -696,7 +696,7 @@ def determine_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg. exits does not change. :param sdfg: SDFG - :param state: State of interest + :param graph: State of interest :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph :param map_entries: List of outermost scoped map entries in the subgraph :param map_exits: List of map exits corresponding to map_entries in order @@ -726,9 +726,9 @@ def determine_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg. def prepare_intermediate_nodes(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - in_nodes: List[nodes.AccessNode], - out_nodes: List[nodes.AccessNode], - intermediate_nodes: List[nodes.AccessNode], + in_nodes: Set[nodes.AccessNode], + out_nodes: Set[nodes.AccessNode], + intermediate_nodes: Set[nodes.AccessNode], map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit], do_not_override: List[str] = []): @@ -814,7 +814,7 @@ def fuse(self, global_map_entry = nodes.MapEntry(global_map) global_map_exit = nodes.MapExit(global_map) - schedule = map_entries[0].schedule + schedule = map_entries[0].map.schedule global_map_entry.schedule = schedule graph.add_node(global_map_entry) graph.add_node(global_map_exit) From 1393052cc2acad472d585b799f19eb72449ebf17 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 11:21:10 +0200 Subject: [PATCH 07/16] Remove a couple of redundant `== True`s. --- dace/transformation/subgraph/subgraph_fusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 7813754026..ec905be772 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -254,7 +254,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: return False # 2.6 Check for disjoint accesses for arrays that cannot be compressed - if self.disjoint_subsets == True: + if self.disjoint_subsets: container_dict = defaultdict(list) for node in chain(in_nodes, intermediate_nodes, out_nodes): if isinstance(node, nodes.AccessNode): @@ -335,7 +335,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: intersection = rng_1dim.intersects(orng_1dim) except TypeError: return False - if intersection is None or intersection == True: + if intersection is None or intersection: warnings.warn("SubgraphFusion::Disjoint Accesses found!") return False From f6bc8fb62898d376a4cd1738b27b623d7a74aff4 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 11:23:32 +0200 Subject: [PATCH 08/16] Since `port_created` is only a pair (or None), it does not have `st` and `nd` fields. --- dace/transformation/subgraph/subgraph_fusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index ec905be772..d5c9b65795 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -963,8 +963,7 @@ def fuse(self, port_created = (in_conn, out_conn) else: - in_conn = port_created.st - out_conn = port_created.nd + in_conn, out_conn = port_created # map graph.add_edge(global_map_exit, out_conn, dst, out_edge.dst_conn, dcpy(out_edge.data)) From e518a55f6a41cbd1572961dfa96b4e39e651b07b Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 11:26:35 +0200 Subject: [PATCH 09/16] Fix another typehint. --- dace/transformation/subgraph/subgraph_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index d5c9b65795..a01786cd0e 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -537,7 +537,7 @@ def copy_edge(self, graph.remove_edge(edge) return ret - def adjust_arrays_nsdfg(self, sdfg: dace.sdfg.SDFG, nsdfg: nodes.NestedSDFG, name: str, nname: str, memlet: Memlet): + def adjust_arrays_nsdfg(self, sdfg: dace.sdfg.SDFG, nsdfg: dace.sdfg.SDFG, name: str, nname: str, memlet: Memlet): """ DFS to replace strides and volumes of data that exhibits nested SDFGs adjacent to its corresponding access nodes, applied during post-processing From b9c718a0658c7e18241c2a78b2b27185f7fe7476 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 11:30:57 +0200 Subject: [PATCH 10/16] Fixes for various minor PEP 8 complaints (which are otherwise no-op). --- .../subgraph/subgraph_fusion.py | 61 +++++++++---------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index a01786cd0e..1162c7ebad 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -40,26 +40,26 @@ class SubgraphFusion(transformation.SubgraphTransformation): transient_allocation = EnumProperty(dtype=dtypes.StorageType, desc="Storage Location to push transients to that are " - "fully contained within the subgraph.", + "fully contained within the subgraph.", default=dtypes.StorageType.Default) schedule_innermaps = Property(desc="Schedule of inner maps. If none, " - "keeps schedule.", + "keeps schedule.", dtype=dtypes.ScheduleType, default=None, allow_none=True) consolidate = Property(desc="Consolidate edges that enter and exit the fused map.", dtype=bool, default=False) propagate = Property(desc="Propagate memlets of edges that enter and exit the fused map." - "Disable if this causes problems (e.g., if memlet propagation does" - "not work correctly).", + "Disable if this causes problems (e.g., if memlet propagation does" + "not work correctly).", dtype=bool, default=True) disjoint_subsets = Property(desc="Check for disjoint subsets in can_be_applied. If multiple" - "access nodes pointing to the same data appear within a subgraph" - "to be fused, this check confirms that their access sets are" - "independent per iteration space to avoid race conditions.", + "access nodes pointing to the same data appear within a subgraph" + "to be fused, this check confirms that their access sets are" + "independent per iteration space to avoid race conditions.", dtype=bool, default=True) @@ -160,8 +160,8 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: if in_edge.src in map_exits: for iedge in graph.in_edges(in_edge.src): if iedge.dst_conn[2:] == in_edge.src_conn[3:]: - subset_to_add = dcpy(iedge.data.subset if iedge.data.data == - node.data else iedge.data.other_subset) + subset_to_add = dcpy(iedge.data.subset + if iedge.data.data == node.data else iedge.data.other_subset) subset_to_add.pop(dims_to_discard) upper_subsets.add(subset_to_add) @@ -177,8 +177,8 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: if out_edge.dst in map_entries: for oedge in graph.out_edges(out_edge.dst): if oedge.src_conn and oedge.src_conn[3:] == out_edge.dst_conn[2:]: - subset_to_add = dcpy(oedge.data.subset if oedge.data.data == - node.data else oedge.data.other_subset) + subset_to_add = dcpy(oedge.data.subset + if oedge.data.data == node.data else oedge.data.other_subset) subset_to_add.pop(dims_to_discard) lower_subsets.add(subset_to_add) @@ -329,8 +329,8 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: subset_minus.replace(repl_dict) for (rng, orng) in zip(subset_plus, subset_minus): - rng_1dim = subsets.Range((rng, )) - orng_1dim = subsets.Range((orng, )) + rng_1dim = subsets.Range((rng,)) + orng_1dim = subsets.Range((orng,)) try: intersection = rng_1dim.intersects(orng_1dim) except TypeError: @@ -471,11 +471,11 @@ def get_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGSt if in_edge.src in map_exits: other_edge = graph.memlet_path(in_edge)[-2] other_subset = other_edge.data.subset \ - if other_edge.data.data == node.data \ - else other_edge.data.other_subset + if other_edge.data.data == node.data \ + else other_edge.data.other_subset for (idx, (ssbs1, ssbs2)) \ - in enumerate(zip(in_edge.data.subset, other_subset)): + in enumerate(zip(in_edge.data.subset, other_subset)): if ssbs1 != ssbs2: variant_dimensions.add(idx) else: @@ -494,8 +494,8 @@ def get_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGSt for other_edge in graph.out_edges(out_edge.dst): if other_edge.src_conn and other_edge.src_conn[3:] == out_edge.dst_conn[2:]: other_subset = other_edge.data.subset \ - if other_edge.data.data == node.data \ - else other_edge.data.other_subset + if other_edge.data.data == node.data \ + else other_edge.data.other_subset for (idx, (ssbs1, ssbs2)) in enumerate(zip(out_edge.data.subset, other_subset)): if ssbs1 != ssbs2: variant_dimensions.add(idx) @@ -629,15 +629,15 @@ def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, # if so, add to data_counter_subgraph # do not add if it is in out_nodes / in_nodes if state == graph and \ - (node in intermediate_nodes or scope_dict[node] in map_entries): + (node in intermediate_nodes or scope_dict[node] in map_entries): data_counter_subgraph[node.data] += 1 # next up: If intermediate_counter and global counter match and if the array # is declared transient, it is fully contained by the subgraph - subgraph_contains_data = {data: data_counter[data] == data_counter_subgraph[data] \ - and sdfg.data(data).transient \ - and data not in do_not_override \ + subgraph_contains_data = {data: data_counter[data] == data_counter_subgraph[data] + and sdfg.data(data).transient + and data not in do_not_override for data in data_intermediate} return subgraph_contains_data @@ -688,7 +688,8 @@ def clone_intermediate_nodes(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGSt return transients_created def determine_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - intermediate_nodes: Iterable[nodes.AccessNode], map_entries: List[nodes.MapEntry], + intermediate_nodes: Iterable[nodes.AccessNode], + map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit]): """ Determines the invariant dimensions for each node -- dimensions in @@ -825,9 +826,9 @@ def fuse(self, # intermediate_nodes simultaneously # also check which dimensions of each transient data element correspond # to map axes and write this information into a dict. - node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, \ - intermediate_nodes,\ - map_entries, map_exits, \ + node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, + intermediate_nodes, + map_entries, map_exits, do_not_override) (subgraph_contains_data, transients_created, invariant_dimensions) = node_info @@ -867,13 +868,11 @@ def fuse(self, inconnectors_dict[src] = (edge, in_conn, out_conn) # reroute in edge via global_map_entry - self.copy_edge(graph, edge, new_dst = global_map_entry, \ - new_dst_conn = in_conn) + self.copy_edge(graph, edge, new_dst=global_map_entry, new_dst_conn=in_conn) # map out edges to new map for out_edge in out_edges: - self.copy_edge(graph, out_edge, new_src = global_map_entry, \ - new_src_conn = out_conn) + self.copy_edge(graph, out_edge, new_src=global_map_entry, new_src_conn=out_conn) else: # connect directly @@ -1113,7 +1112,7 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s if len(in_edges) > 1: for oedge in out_edges: if oedge.dst == global_map_exit and \ - oedge.data.other_subset is None: + oedge.data.other_subset is None: oedge.data.other_subset = dcpy(oedge.data.subset) oedge.data.other_subset.offset(min_offset, True) From 778582501777d8b91592bc59070beff69ea9d6b9 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 12:15:39 +0200 Subject: [PATCH 11/16] Typehints and set literal in helpers. --- dace/transformation/helpers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 6ca4602079..b003ca36d1 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -7,7 +7,7 @@ from dace.sdfg.state import ControlFlowRegion from dace.subsets import Range, Subset, union import dace.subsets as subsets -from typing import Dict, List, Optional, Tuple, Set, Union +from typing import Dict, List, Optional, Tuple, Set, Union, Iterable from dace import data, dtypes, symbolic from dace.codegen import control_flow as cf @@ -275,7 +275,7 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: if isinstance(child, cf.BasicCFBlock): if child.state in visited: continue - components[child.state] = (set([child.state]), child) + components[child.state] = ({child.state}, child) visited[child.state] = False elif isinstance(child, (cf.ForScope, cf.WhileScope)): guard = child.guard @@ -1031,11 +1031,11 @@ def are_subsets_contiguous(subset_a: subsets.Subset, subset_b: subsets.Subset, d return False -def find_contiguous_subsets(subset_list: List[subsets.Subset], dim: int = None) -> Set[subsets.Subset]: +def find_contiguous_subsets(subset_list: Iterable[subsets.Subset], dim: int = None) -> Set[subsets.Subset]: """ Finds the set of largest contiguous subsets in a list of subsets. - :param subsets: Iterable of subset objects. + :param subset_list: Iterable of subset objects. :param dim: Check for contiguity only for the specified dimension. :return: A list of contiguous subsets. """ From 3418d9e74ec849c43c8ca5d68b8fd33222ffaf0a Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 12:20:53 +0200 Subject: [PATCH 12/16] Move various helper functions out of `@staticmethod`, since they don't need to be tied to the class itself. Remove their unnecessary arguments. --- .../transformation/subgraph/stencil_tiling.py | 6 +- .../subgraph/subgraph_fusion.py | 915 +++++++++--------- 2 files changed, 460 insertions(+), 461 deletions(-) diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 1ba86252c4..c9c3e9afd4 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -200,13 +200,13 @@ def can_be_applied(sdfg, subgraph) -> bool: # get intermediate_nodes, out_nodes from SubgraphFusion Transformation try: - node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph, map_entries) + node_config = get_adjacent_nodes(sdfg, graph, map_entries) (_, intermediate_nodes, out_nodes) = node_config except NotImplementedError: return False # 1.4: check topological feasibility - if not SubgraphFusion.check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes): + if not check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes): return False # 1.5 nodes that are both intermediate and out nodes # are not supported in StencilTiling @@ -215,7 +215,7 @@ def can_be_applied(sdfg, subgraph) -> bool: # 1.6 check that we only deal with compressible transients - subgraph_contains_data = SubgraphFusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes, + subgraph_contains_data = determine_compressible_nodes(sdfg, graph, intermediate_nodes, map_entries, map_exits) if any([s == False for s in subgraph_contains_data.values()]): return False diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 1162c7ebad..5c97c6845f 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -21,6 +21,433 @@ from dace.transformation.subgraph import helpers +def get_invariant_dimensions(graph: dace.sdfg.SDFGState, + map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit], + node: nodes.AccessNode): + """ + For a given intermediate access node, return a set of indices that correspond to array / subset dimensions in which + no change is observed upon propagation through the corresponding map nodes in map_entries / map_exits. + + :param map_entries: List of outermost scoped map entries + :param map_exits: List of corresponding exit nodes to map_entries, in order + :param node: Intermediate access node of interest + :return: Set of invariant integer dimensions + """ + variant_dimensions = set() + subset_length = -1 + + for in_edge in graph.in_edges(node): + other_subset = None + if in_edge.src in map_exits: + other_edge = graph.memlet_path(in_edge)[-2] + other_subset = other_edge.data.subset \ + if other_edge.data.data == node.data \ + else other_edge.data.other_subset + + for (idx, (ssbs1, ssbs2)) \ + in enumerate(zip(in_edge.data.subset, other_subset)): + if ssbs1 != ssbs2: + variant_dimensions.add(idx) + else: + warnings.warn("SubgraphFusion::Nodes between two maps to be" + "fused with *incoming* edges" + "from outside the maps are not" + "allowed yet.") + + if subset_length < 0: + subset_length = other_subset.dims() + else: + assert other_subset.dims() == subset_length + + for out_edge in graph.out_edges(node): + if out_edge.dst in map_entries: + for other_edge in graph.out_edges(out_edge.dst): + if other_edge.src_conn and other_edge.src_conn[3:] == out_edge.dst_conn[2:]: + other_subset = other_edge.data.subset \ + if other_edge.data.data == node.data \ + else other_edge.data.other_subset + for (idx, (ssbs1, ssbs2)) in enumerate(zip(out_edge.data.subset, other_subset)): + if ssbs1 != ssbs2: + variant_dimensions.add(idx) + assert other_subset.dims() == subset_length + + invariant_dimensions = set([i for i in range(subset_length)]) - variant_dimensions + return invariant_dimensions + + +def get_adjacent_nodes(graph, map_entries) \ + -> Tuple[Set[nodes.AccessNode], Set[nodes.AccessNode], Set[nodes.AccessNode]]: + """ + For given map entries, finds a set of in, out and intermediate nodes as defined below + + :param graph: State of interest + :param map_entries: List of all outermost scoped maps that induce the subgraph + :return: Tuple of (in_nodes, intermediate_nodes, out_nodes) + + - In_nodes are nodes that serve as pure input nodes for the map entries + - Out nodes are nodes that serve as pure output nodes for the map entries + - Interemdiate nodes are nodes that serve as buffer storage between outermost scoped map entries and exits + of the induced subgraph + + -> in_nodes are trivially disjoint from the other two types of access nodes + -> Intermediate_nodes and out_nodes are not necessarily disjoint + + """ + + # Nodes that flow into one or several maps but no data is flowed to them from any map + in_nodes = set() + + # Nodes into which data is flowed but that no data flows into any map from them + out_nodes = set() + + # Nodes that act as intermediate node - data flows from a map into them and then there + # is an outgoing path into another map + intermediate_nodes = set() + + map_exits = [graph.exit_node(map_entry) for map_entry in map_entries] + for map_entry, map_exit in zip(map_entries, map_exits): + for edge in graph.in_edges(map_entry): + in_nodes.add(edge.src) + for edge in graph.out_edges(map_exit): + current_node = edge.dst + if len(graph.out_edges(current_node)) == 0: + out_nodes.add(current_node) + else: + for dst_edge in graph.out_edges(current_node): + if dst_edge.dst in map_entries: + # add to intermediate_nodes + intermediate_nodes.add(current_node) + + else: + # add to out_nodes + out_nodes.add(current_node) + + # any intermediate_nodes currently in in_nodes shouldn't be there + in_nodes -= intermediate_nodes + + for node in intermediate_nodes: + for e in graph.in_edges(node): + if e.src not in map_exits: + warnings.warn("SubgraphFusion::Nodes between two maps to be" + "fused with *incoming* edges" + "from outside the maps are not" + "allowed yet.") + raise NotImplementedError() + + return in_nodes, intermediate_nodes, out_nodes + + +def copy_edge(graph, + edge, + new_src=None, + new_src_conn=None, + new_dst=None, + new_dst_conn=None, + new_data=None, + remove_old=False): + """ + Copies an edge going from source to dst. + If no destination is specified, the edge is copied with the same + destination and port as the original edge, else the edge is copied + with the new destination and the new port. + If no source is specified, the edge is copied with the same + source and port as the original edge, else the edge is copied + with the new source and the new port + If remove_old is specified, the old edge is removed immediately + If new_data is specified, inserts new_data as a memlet, + else makes a deepcopy of the current edges memlet + """ + data = new_data if new_data else dcpy(edge.data) + src = edge.src if new_src is None else new_src + src_conn = edge.src_conn if new_src is None else new_src_conn + dst = edge.dst if new_dst is None else new_dst + dst_conn = edge.dst_conn if new_dst is None else new_dst_conn + + ret = graph.add_edge(src, src_conn, dst, dst_conn, data) + + if remove_old: + graph.remove_edge(edge) + return ret + + +def clone_intermediate_nodes(sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, + intermediate_nodes: Set[nodes.AccessNode], out_nodes: Set[nodes.AccessNode], + map_entries: List[nodes.MapEntry]): + """ + Creates cloned access nodes and data arrays for nodes that are both in intermediate nodes + and out nodes, redirecting output from the original node to the cloned node. Operates in-place. + + :param sdfg: SDFG + :param graph: State of interest + :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph + :param out_nodes: List of out nodes appearing in a fusible subgraph + :param map_entries: List of outermost scoped map entries in the subgraph + :return: A dict that maps each intermediate node that also functions as an out node + to the respective cloned transient node + """ + + transients_created = {} + for node in intermediate_nodes & out_nodes: + # create new transient at exit replacing the array + # and redirect all traffic + data_ref = sdfg.data(node.data) + + out_trans_data_name = node.data + '_OUT' + out_trans_data_name = sdfg._find_new_name(out_trans_data_name) + data_trans = sdfg.add_transient(name=out_trans_data_name, + shape=dcpy(data_ref.shape), + dtype=dcpy(data_ref.dtype), + storage=dcpy(data_ref.storage), + offset=dcpy(data_ref.offset)) + node_trans = graph.add_access(out_trans_data_name) + if node.setzero: + node_trans.setzero = True + + # redirect all relevant traffic from node_trans to node + edges = list(graph.out_edges(node)) + for edge in edges: + if edge.dst not in map_entries: + copy_edge(graph, edge, new_src=node_trans, remove_old=True) + + graph.add_edge(node, None, node_trans, None, Memlet()) + + transients_created[node] = node_trans + + return transients_created + + +def determine_invariant_dimensions(graph: dace.sdfg.SDFGState, + intermediate_nodes: Iterable[nodes.AccessNode], + map_entries: List[nodes.MapEntry], + map_exits: List[nodes.MapExit]): + """ + Determines the invariant dimensions for each node -- dimensions in + which the access set of the memlets propagated through map entries and + exits does not change. + + :param graph: State of interest + :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph + :param map_entries: List of outermost scoped map entries in the subgraph + :param map_exits: List of map exits corresponding to map_entries in order + :return: A dict mapping each intermediate node (nodes.AccessNode) to a list of integer dimensions + """ + # create dict for every array that for which + # subgraph_contains_data is true that lists invariant axes. + invariant_dimensions = {} + for node in intermediate_nodes: + data = node.data + inv_dims = get_invariant_dimensions(graph, map_entries, map_exits, node) + if node in invariant_dimensions: + # do a check -- we want the same result for each + # node containing the same data + if not inv_dims == invariant_dimensions[node]: + warnings.warn(f"SubgraphFusion::Data dimensions that are not propagated through differ" + "across multiple instances of access nodes for data {node.data}" + "Please check whether all memlets to AccessNodes containing" + "this data are sound.") + invariant_dimensions[data] |= inv_dims + + else: + invariant_dimensions[data] = inv_dims + + return invariant_dimensions + + +def prepare_intermediate_nodes(sdfg: dace.sdfg.SDFG, + graph: dace.sdfg.SDFGState, + out_nodes: Set[nodes.AccessNode], + intermediate_nodes: Set[nodes.AccessNode], + map_entries: List[nodes.MapEntry], + map_exits: List[nodes.MapExit], + do_not_override: List[str] = []): + """ + Helper function that computes the following information: + 1. Determine whether intermediate nodes only appear within the induced fusible subgraph. This is equivalent to checking for compresssibility. + 2. Determine whether any intermediate transients are also out nodes, if so they have to be cloned + 3. Determine invariant dimensions for any intermediate transients (that are compressible). + + :return: A tuple (subgraph_contains_data, transients_created, invariant_dimensions) + of dictionaries containing the necessary information + """ + + # 1. Compressibility + subgraph_contains_data = determine_compressible_nodes(sdfg, graph, intermediate_nodes, + map_entries, do_not_override) + # 2. Clone intermediate & out transients + transients_created = clone_intermediate_nodes(sdfg, graph, intermediate_nodes, out_nodes, map_entries) + # 3. Gather invariant dimensions + invariant_dimensions = determine_invariant_dimensions(graph, intermediate_nodes, map_entries, + map_exits) + + return subgraph_contains_data, transients_created, invariant_dimensions + + +def adjust_arrays_nsdfg(sdfg: dace.sdfg.SDFG, nsdfg: dace.sdfg.SDFG, name: str, nname: str, memlet: Memlet): + """ + DFS to replace strides and volumes of data that exhibits nested SDFGs + adjacent to its corresponding access nodes, applied during post-processing + of a fused graph. Operates in-place. + + :param sdfg: SDFG + :param nsdfg: The Nested SDFG of interest + :param name: Name of the array in the SDFG + :param nname: Name of the array in the nested SDFG + :param memlet: Memlet adjacent to the nested SDFG that leads to the + access node with the corresponding data name + """ + # check whether array needs to change + if len(sdfg.data(name).shape) != len(nsdfg.data(nname).shape): + subset_copy = dcpy(memlet.subset) + non_ones = subset_copy.squeeze() + strides = [] + total_size = 1 + + if non_ones: + strides = [] + total_size = 1 + for (i, (sh, st)) in enumerate(zip(sdfg.data(name).shape, sdfg.data(name).strides)): + if i in non_ones: + strides.append(st) + total_size *= sh + else: + strides = [1] + total_size = 1 + + if isinstance(nsdfg.data(nname), data.Array): + nsdfg.data(nname).strides = tuple(strides) + nsdfg.data(nname).total_size = total_size + + else: + if isinstance(nsdfg.data(nname), data.Array): + nsdfg.data(nname).strides = sdfg.data(name).strides + nsdfg.data(nname).total_size = sdfg.data(name).total_size + + # traverse the whole graph and search for arrays + for ngraph in nsdfg.nodes(): + for nnode in ngraph.nodes(): + if isinstance(nnode, nodes.AccessNode) and nnode.label == nname: + # trace and recurse if necessary + for e in chain(ngraph.out_edges(nnode), ngraph.in_edges(nnode)): + for te in ngraph.memlet_tree(e): + if isinstance(te.dst, nodes.NestedSDFG): + adjust_arrays_nsdfg(nsdfg, te.dst.sdfg, nname, te.dst_conn, te.data) + if isinstance(te.src, nodes.NestedSDFG): + adjust_arrays_nsdfg(nsdfg, te.src.sdfg, nname, te.src_conn, te.data) + + +def check_topo_feasibility(graph, map_entries, intermediate_nodes, out_nodes): + """ + Checks whether given outermost scoped map entries have topological structure apt for fusion + + :param graph: State + :param map_entries: List of outermost scoped map entries induced by subgraph + :param intermediate_nodes: List of intermediate access nodes + :param out_nodes: List of outgoing access nodes + :return: Boolean value indicating fusibility + """ + # For each intermediate and out node: must never reach any map + # entry if it is not connected to map entry immediately + + # for memoization purposes + visited = set() + + def visit_descendants(graph, node, visited, map_entries): + # check whether the node has already been processed once + if node in visited: + return True + # check whether the node is in our map entries. + if node in map_entries: + return False + # for every out edge, continue exploring whether + # we and up at another map entry that is in our set + for oedge in graph.out_edges(node): + if not visit_descendants(graph, oedge.dst, visited, map_entries): + return False + + # this node does not lead to any other map entries, add to visited + visited.add(node) + return True + + for node in intermediate_nodes | out_nodes: + # these nodes must not lead to a map entry + nodes_to_check = set() + for oedge in graph.out_edges(node): + if oedge.dst not in map_entries: + nodes_to_check.add(oedge.dst) + + for forbidden_node in nodes_to_check: + if not visit_descendants(graph, forbidden_node, visited, map_entries): + return False + + return True + + +def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, + graph: dace.sdfg.SDFGState, + intermediate_nodes: Iterable[nodes.AccessNode], + map_entries: List[nodes.MapEntry], + do_not_override: List[str] = []): + """ + Checks for all intermediate nodes whether they appear + only within the induced fusible subgraph my map_entries and map_exits. + This is returned as a dict that contains a boolean value for each + intermediate node as a key. + + :param sdfg: SDFG + :param graph: State of interest + :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph + :param map_entries: List of outermost scoped map entries in the subgraph + :param do_not_override: List of data array names not to be compressed + :return: A dictionary indicating for each data string whether its array can be compressed + """ + + # search whether intermediate_nodes appear outside of subgraph + # and store it in dict + data_counter = defaultdict(int) + data_counter_subgraph = defaultdict(int) + + data_intermediate = set([node.data for node in intermediate_nodes]) + + # do a full global search and count each data from each intermediate node + scope_dict = graph.scope_dict() + for state in sdfg.nodes(): + for node in state.nodes(): + if isinstance(node, nodes.AccessNode) and node.data in data_intermediate: + # add them to the counter set in all cases + data_counter[node.data] += 1 + # see whether we are inside the subgraph scope + # if so, add to data_counter_subgraph + # do not add if it is in out_nodes / in_nodes + if state == graph and \ + (node in intermediate_nodes or scope_dict[node] in map_entries): + data_counter_subgraph[node.data] += 1 + + # next up: If intermediate_counter and global counter match and if the array + # is declared transient, it is fully contained by the subgraph + + subgraph_contains_data = {data: data_counter[data] == data_counter_subgraph[data] + and sdfg.data(data).transient + and data not in do_not_override + for data in data_intermediate} + return subgraph_contains_data + + +def change_data(transient_array: dace.data.Array, shape, strides, total_size, offset, lifetime, storage): + """Compress original shape""" + if shape is not None: + transient_array.shape = shape + if strides is not None: + transient_array.strides = strides + if total_size is not None: + transient_array.total_size = total_size + if offset is not None: + transient_array.offset = offset + if lifetime is not None: + transient_array.lifetime = lifetime + if storage is not None: + transient_array.storage = storage + + @make_properties class SubgraphFusion(transformation.SubgraphTransformation): """ @@ -120,13 +547,13 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: # 2.1 do some preparation work first: # calculate node topology (see apply for definition) try: - node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph, map_entries) + node_config = get_adjacent_nodes(graph, map_entries) except NotImplementedError: return False in_nodes, intermediate_nodes, out_nodes = node_config # 2.2 topological feasibility: - if not SubgraphFusion.check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes): + if not check_topo_feasibility(graph, map_entries, intermediate_nodes, out_nodes): return False # 2.3 memlet feasibility @@ -135,8 +562,8 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: # of the next entering map. # We also check for any WCRs on the fly. try: - invariant_dimensions = self.determine_invariant_dimensions(sdfg, graph, intermediate_nodes, map_entries, - map_exits) + invariant_dimensions = determine_invariant_dimensions(graph, intermediate_nodes, map_entries, + map_exits) except NotImplementedError: return False @@ -232,12 +659,13 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: for out_node in out_nodes: for in_edge in graph.in_edges(out_node): if in_edge.src in map_exits and in_edge.data.wcr: - if in_edge.data.data in in_data or in_edge.data.data in intermediate_data or in_edge.data.data in view_data: + if (in_edge.data.data in in_data + or in_edge.data.data in intermediate_data + or in_edge.data.data in view_data): return False # Check compressibility for each intermediate node -- this is needed in the following checks - is_compressible = SubgraphFusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes, map_entries, - map_exits) + is_compressible = determine_compressible_nodes(sdfg, graph, intermediate_nodes, map_entries) # 2.5 Intermediate Arrays must not connect to ArrayViews for n in intermediate_nodes: @@ -341,420 +769,6 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: return True - @staticmethod - def get_adjacent_nodes( - sdfg, graph, map_entries) -> Tuple[Set[nodes.AccessNode], Set[nodes.AccessNode], Set[nodes.AccessNode]]: - """ - For given map entries, finds a set of in, out and intermediate nodes as defined below - - :param sdfg: SDFG - :param graph: State of interest - :param map_entries: List of all outermost scoped maps that induce the subgraph - :return: Tuple of (in_nodes, intermediate_nodes, out_nodes) - - - In_nodes are nodes that serve as pure input nodes for the map entries - - Out nodes are nodes that serve as pure output nodes for the map entries - - Interemdiate nodes are nodes that serve as buffer storage between outermost scoped map entries and exits - of the induced subgraph - - -> in_nodes are trivially disjoint from the other two types of access nodes - -> Intermediate_nodes and out_nodes are not necessarily disjoint - - """ - - # Nodes that flow into one or several maps but no data is flowed to them from any map - in_nodes = set() - - # Nodes into which data is flowed but that no data flows into any map from them - out_nodes = set() - - # Nodes that act as intermediate node - data flows from a map into them and then there - # is an outgoing path into another map - intermediate_nodes = set() - - map_exits = [graph.exit_node(map_entry) for map_entry in map_entries] - for map_entry, map_exit in zip(map_entries, map_exits): - for edge in graph.in_edges(map_entry): - in_nodes.add(edge.src) - for edge in graph.out_edges(map_exit): - current_node = edge.dst - if len(graph.out_edges(current_node)) == 0: - out_nodes.add(current_node) - else: - for dst_edge in graph.out_edges(current_node): - if dst_edge.dst in map_entries: - # add to intermediate_nodes - intermediate_nodes.add(current_node) - - else: - # add to out_nodes - out_nodes.add(current_node) - - # any intermediate_nodes currently in in_nodes shouldn't be there - in_nodes -= intermediate_nodes - - for node in intermediate_nodes: - for e in graph.in_edges(node): - if e.src not in map_exits: - warnings.warn("SubgraphFusion::Nodes between two maps to be" - "fused with *incoming* edges" - "from outside the maps are not" - "allowed yet.") - raise NotImplementedError() - - return (in_nodes, intermediate_nodes, out_nodes) - - @staticmethod - def check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes): - """ - Checks whether given outermost scoped map entries have topological structure apt for fusion - - :param sdfg: SDFG - :param graph: State - :param map_entries: List of outermost scoped map entries induced by subgraph - :param intermediate_nodes: List of intermediate access nodes - :param out_nodes: List of outgoing access nodes - :return: Boolean value indicating fusibility - """ - # For each intermediate and out node: must never reach any map - # entry if it is not connected to map entry immediately - - # for memoization purposes - visited = set() - - def visit_descendants(graph, node, visited, map_entries): - # check whether the node has already been processed once - if node in visited: - return True - # check whether the node is in our map entries. - if node in map_entries: - return False - # for every out edge, continue exploring whether - # we and up at another map entry that is in our set - for oedge in graph.out_edges(node): - if not visit_descendants(graph, oedge.dst, visited, map_entries): - return False - - # this node does not lead to any other map entries, add to visited - visited.add(node) - return True - - for node in intermediate_nodes | out_nodes: - # these nodes must not lead to a map entry - nodes_to_check = set() - for oedge in graph.out_edges(node): - if oedge.dst not in map_entries: - nodes_to_check.add(oedge.dst) - - for forbidden_node in nodes_to_check: - if not visit_descendants(graph, forbidden_node, visited, map_entries): - return False - - return True - - def get_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit], - node: nodes.AccessNode): - """ - For a given intermediate access node, return a set of indices that correspond to array / subset dimensions in which no change is observed - upon propagation through the corresponding map nodes in map_entries / map_exits. - - :param map_entries: List of outermost scoped map entries - :param map_exits: List of corresponding exit nodes to map_entries, in order - :param node: Intermediate access node of interest - :return: Set of invariant integer dimensions - """ - variant_dimensions = set() - subset_length = -1 - - for in_edge in graph.in_edges(node): - if in_edge.src in map_exits: - other_edge = graph.memlet_path(in_edge)[-2] - other_subset = other_edge.data.subset \ - if other_edge.data.data == node.data \ - else other_edge.data.other_subset - - for (idx, (ssbs1, ssbs2)) \ - in enumerate(zip(in_edge.data.subset, other_subset)): - if ssbs1 != ssbs2: - variant_dimensions.add(idx) - else: - warnings.warn("SubgraphFusion::Nodes between two maps to be" - "fused with *incoming* edges" - "from outside the maps are not" - "allowed yet.") - - if subset_length < 0: - subset_length = other_subset.dims() - else: - assert other_subset.dims() == subset_length - - for out_edge in graph.out_edges(node): - if out_edge.dst in map_entries: - for other_edge in graph.out_edges(out_edge.dst): - if other_edge.src_conn and other_edge.src_conn[3:] == out_edge.dst_conn[2:]: - other_subset = other_edge.data.subset \ - if other_edge.data.data == node.data \ - else other_edge.data.other_subset - for (idx, (ssbs1, ssbs2)) in enumerate(zip(out_edge.data.subset, other_subset)): - if ssbs1 != ssbs2: - variant_dimensions.add(idx) - assert other_subset.dims() == subset_length - - invariant_dimensions = set([i for i in range(subset_length)]) - variant_dimensions - return invariant_dimensions - - def copy_edge(self, - graph, - edge, - new_src=None, - new_src_conn=None, - new_dst=None, - new_dst_conn=None, - new_data=None, - remove_old=False): - """ - Copies an edge going from source to dst. - If no destination is specified, the edge is copied with the same - destination and port as the original edge, else the edge is copied - with the new destination and the new port. - If no source is specified, the edge is copied with the same - source and port as the original edge, else the edge is copied - with the new source and the new port - If remove_old is specified, the old edge is removed immediately - If new_data is specified, inserts new_data as a memlet, else - else makes a deepcopy of the current edges memlet - """ - data = new_data if new_data else dcpy(edge.data) - src = edge.src if new_src is None else new_src - src_conn = edge.src_conn if new_src is None else new_src_conn - dst = edge.dst if new_dst is None else new_dst - dst_conn = edge.dst_conn if new_dst is None else new_dst_conn - - ret = graph.add_edge(src, src_conn, dst, dst_conn, data) - - if remove_old: - graph.remove_edge(edge) - return ret - - def adjust_arrays_nsdfg(self, sdfg: dace.sdfg.SDFG, nsdfg: dace.sdfg.SDFG, name: str, nname: str, memlet: Memlet): - """ - DFS to replace strides and volumes of data that exhibits nested SDFGs - adjacent to its corresponding access nodes, applied during post-processing - of a fused graph. Operates in-place. - - :param sdfg: SDFG - :param nsdfg: The Nested SDFG of interest - :param name: Name of the array in the SDFG - :param nname: Name of the array in the nested SDFG - :param memlet: Memlet adjacent to the nested SDFG that leads to the - access node with the corresponding data name - """ - # check whether array needs to change - if len(sdfg.data(name).shape) != len(nsdfg.data(nname).shape): - subset_copy = dcpy(memlet.subset) - non_ones = subset_copy.squeeze() - strides = [] - total_size = 1 - - if non_ones: - strides = [] - total_size = 1 - for (i, (sh, st)) in enumerate(zip(sdfg.data(name).shape, sdfg.data(name).strides)): - if i in non_ones: - strides.append(st) - total_size *= sh - else: - strides = [1] - total_size = 1 - - if isinstance(nsdfg.data(nname), data.Array): - nsdfg.data(nname).strides = tuple(strides) - nsdfg.data(nname).total_size = total_size - - else: - if isinstance(nsdfg.data(nname), data.Array): - nsdfg.data(nname).strides = sdfg.data(name).strides - nsdfg.data(nname).total_size = sdfg.data(name).total_size - - # traverse the whole graph and search for arrays - for ngraph in nsdfg.nodes(): - for nnode in ngraph.nodes(): - if isinstance(nnode, nodes.AccessNode) and nnode.label == nname: - # trace and recurse if necessary - for e in chain(ngraph.out_edges(nnode), ngraph.in_edges(nnode)): - for te in ngraph.memlet_tree(e): - if isinstance(te.dst, nodes.NestedSDFG): - self.adjust_arrays_nsdfg(nsdfg, te.dst.sdfg, nname, te.dst_conn, te.data) - if isinstance(te.src, nodes.NestedSDFG): - self.adjust_arrays_nsdfg(nsdfg, te.src.sdfg, nname, te.src_conn, te.data) - - @staticmethod - def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, - graph: dace.sdfg.SDFGState, - intermediate_nodes: Iterable[nodes.AccessNode], - map_entries: List[nodes.MapEntry], - map_exits: List[nodes.MapExit], - do_not_override: List[str] = []): - """ - Checks for all intermediate nodes whether they appear - only within the induced fusible subgraph my map_entries and map_exits. - This is returned as a dict that contains a boolean value for each - intermediate node as a key. - - :param sdfg: SDFG - :param graph: State of interest - :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph - :param map_entries: List of outermost scoped map entries in the subgraph - :param map_exits: List of map exits corresponding to map_entries in order - :param do_not_override: List of data array names not to be compressed - :return: A dictionary indicating for each data string whether its array can be compressed - """ - - # search whether intermediate_nodes appear outside of subgraph - # and store it in dict - data_counter = defaultdict(int) - data_counter_subgraph = defaultdict(int) - - data_intermediate = set([node.data for node in intermediate_nodes]) - - # do a full global search and count each data from each intermediate node - scope_dict = graph.scope_dict() - for state in sdfg.nodes(): - for node in state.nodes(): - if isinstance(node, nodes.AccessNode) and node.data in data_intermediate: - # add them to the counter set in all cases - data_counter[node.data] += 1 - # see whether we are inside the subgraph scope - # if so, add to data_counter_subgraph - # do not add if it is in out_nodes / in_nodes - if state == graph and \ - (node in intermediate_nodes or scope_dict[node] in map_entries): - data_counter_subgraph[node.data] += 1 - - # next up: If intermediate_counter and global counter match and if the array - # is declared transient, it is fully contained by the subgraph - - subgraph_contains_data = {data: data_counter[data] == data_counter_subgraph[data] - and sdfg.data(data).transient - and data not in do_not_override - for data in data_intermediate} - return subgraph_contains_data - - def clone_intermediate_nodes(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - intermediate_nodes: Set[nodes.AccessNode], out_nodes: Set[nodes.AccessNode], - map_entries: List[nodes.MapEntry], map_exits: List[nodes.MapExit]): - """ - Creates cloned access nodes and data arrays for nodes that are both in intermediate nodes - and out nodes, redirecting output from the original node to the cloned node. Operates in-place. - - :param sdfg: SDFG - :param graph: State of interest - :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph - :param out_nodes: List of out nodes appearing in a fusible subgraph - :param map_entries: List of outermost scoped map entries in the subgraph - :param map_exits: List of map exits corresponding to map_entries in order - :return: A dict that maps each intermediate node that also functions as an out node - to the respective cloned transient node - """ - - transients_created = {} - for node in intermediate_nodes & out_nodes: - # create new transient at exit replacing the array - # and redirect all traffic - data_ref = sdfg.data(node.data) - - out_trans_data_name = node.data + '_OUT' - out_trans_data_name = sdfg._find_new_name(out_trans_data_name) - data_trans = sdfg.add_transient(name=out_trans_data_name, - shape=dcpy(data_ref.shape), - dtype=dcpy(data_ref.dtype), - storage=dcpy(data_ref.storage), - offset=dcpy(data_ref.offset)) - node_trans = graph.add_access(out_trans_data_name) - if node.setzero: - node_trans.setzero = True - - # redirect all relevant traffic from node_trans to node - edges = list(graph.out_edges(node)) - for edge in edges: - if edge.dst not in map_entries: - self.copy_edge(graph, edge, new_src=node_trans, remove_old=True) - - graph.add_edge(node, None, node_trans, None, Memlet()) - - transients_created[node] = node_trans - - return transients_created - - def determine_invariant_dimensions(self, sdfg: dace.sdfg.SDFG, graph: dace.sdfg.SDFGState, - intermediate_nodes: Iterable[nodes.AccessNode], - map_entries: List[nodes.MapEntry], - map_exits: List[nodes.MapExit]): - """ - Determines the invariant dimensions for each node -- dimensions in - which the access set of the memlets propagated through map entries and - exits does not change. - - :param sdfg: SDFG - :param graph: State of interest - :param intermediate_nodes: List of intermediate nodes appearing in a fusible subgraph - :param map_entries: List of outermost scoped map entries in the subgraph - :param map_exits: List of map exits corresponding to map_entries in order - :return: A dict mapping each intermediate node (nodes.AccessNode) to a list of integer dimensions - """ - # create dict for every array that for which - # subgraph_contains_data is true that lists invariant axes. - invariant_dimensions = {} - for node in intermediate_nodes: - data = node.data - inv_dims = self.get_invariant_dimensions(sdfg, graph, map_entries, map_exits, node) - if node in invariant_dimensions: - # do a check -- we want the same result for each - # node containing the same data - if not inv_dims == invariant_dimensions[node]: - warnings.warn(f"SubgraphFusion::Data dimensions that are not propagated through differ" - "across multiple instances of access nodes for data {node.data}" - "Please check whether all memlets to AccessNodes containing" - "this data are sound.") - invariant_dimensions[data] |= inv_dims - - else: - invariant_dimensions[data] = inv_dims - - return invariant_dimensions - - def prepare_intermediate_nodes(self, - sdfg: dace.sdfg.SDFG, - graph: dace.sdfg.SDFGState, - in_nodes: Set[nodes.AccessNode], - out_nodes: Set[nodes.AccessNode], - intermediate_nodes: Set[nodes.AccessNode], - map_entries: List[nodes.MapEntry], - map_exits: List[nodes.MapExit], - do_not_override: List[str] = []): - """ - Helper function that computes the following information: - 1. Determine whether intermediate nodes only appear within the induced fusible subgraph. This is equivalent to checking for compresssibility. - 2. Determine whether any intermediate transients are also out nodes, if so they have to be cloned - 3. Determine invariant dimensions for any intermediate transients (that are compressible). - - :return: A tuple (subgraph_contains_data, transients_created, invariant_dimensions) - of dictionaries containing the necessary information - """ - - # 1. Compressibility - subgraph_contains_data = SubgraphFusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes, - map_entries, map_exits, do_not_override) - # 2. Clone intermediate & out transients - transients_created = self.clone_intermediate_nodes(sdfg, graph, intermediate_nodes, out_nodes, map_entries, - map_exits) - # 3. Gather invariant dimensions - invariant_dimensions = self.determine_invariant_dimensions(sdfg, graph, intermediate_nodes, map_entries, - map_exits) - - return (subgraph_contains_data, transients_created, invariant_dimensions) - def apply(self, sdfg, do_not_override=None, **kwargs): """ Apply the SubgraphFusion Transformation. See @fuse for more details """ subgraph = self.subgraph_view(sdfg) @@ -776,7 +790,7 @@ def fuse(self, Arrays that don't exist outside the subgraph get pushed into the map and their data dimension gets cropped. - Otherwise the original array is taken. + Otherwise, the original array is taken. For every output respective connections are crated automatically. @@ -802,7 +816,7 @@ def fuse(self, map_exits = [graph.exit_node(map_entry) for map_entry in map_entries] # See function documentation for an explanation of these variables - node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph, map_entries) + node_config = get_adjacent_nodes(graph, map_entries) (in_nodes, intermediate_nodes, out_nodes) = node_config if self.debug: @@ -826,10 +840,10 @@ def fuse(self, # intermediate_nodes simultaneously # also check which dimensions of each transient data element correspond # to map axes and write this information into a dict. - node_info = self.prepare_intermediate_nodes(sdfg, graph, in_nodes, out_nodes, - intermediate_nodes, - map_entries, map_exits, - do_not_override) + node_info = prepare_intermediate_nodes(sdfg, graph, out_nodes, + intermediate_nodes, + map_entries, map_exits, + do_not_override) (subgraph_contains_data, transients_created, invariant_dimensions) = node_info if self.debug: @@ -868,29 +882,29 @@ def fuse(self, inconnectors_dict[src] = (edge, in_conn, out_conn) # reroute in edge via global_map_entry - self.copy_edge(graph, edge, new_dst=global_map_entry, new_dst_conn=in_conn) + copy_edge(graph, edge, new_dst=global_map_entry, new_dst_conn=in_conn) # map out edges to new map for out_edge in out_edges: - self.copy_edge(graph, out_edge, new_src=global_map_entry, new_src_conn=out_conn) + copy_edge(graph, out_edge, new_src=global_map_entry, new_src_conn=out_conn) else: # connect directly for out_edge in out_edges: mm = dcpy(out_edge.data) - self.copy_edge(graph, out_edge, new_src=src, new_src_conn=None, new_data=mm) + copy_edge(graph, out_edge, new_src=src, new_src_conn=None, new_data=mm) for edge in graph.out_edges(map_entry): # special case: for nodes that have no data connections if not edge.src_conn: - self.copy_edge(graph, edge, new_src=global_map_entry) + copy_edge(graph, edge, new_src=global_map_entry) ###################################### for edge in graph.in_edges(map_exit): if not edge.dst_conn: # no destination connector, path ends here. - self.copy_edge(graph, edge, new_dst=global_map_exit) + copy_edge(graph, edge, new_dst=global_map_exit) continue # find corresponding out_edges for current edge out_edges = [oedge for oedge in graph.out_edges(map_exit) if oedge.src_conn[3:] == edge.dst_conn[2:]] @@ -943,11 +957,11 @@ def fuse(self, # handle separately: intermediate_nodes and pure out nodes # case 1: intermediate_nodes: can just redirect edge if dst in intermediate_nodes: - self.copy_edge(graph, - out_edge, - new_src=edge.src, - new_src_conn=edge.src_conn, - new_data=dcpy(edge.data)) + copy_edge(graph, + out_edge, + new_src=edge.src, + new_src_conn=edge.src_conn, + new_data=dcpy(edge.data)) # case 2: pure out node: connect to outer array node if dst in (out_nodes - intermediate_nodes): @@ -958,7 +972,7 @@ def fuse(self, out_conn = 'OUT_' + next_conn global_map_exit.add_in_connector(in_conn) global_map_exit.add_out_connector(out_conn) - self.copy_edge(graph, edge, new_dst=global_map_exit, new_dst_conn=in_conn) + copy_edge(graph, edge, new_dst=global_map_exit, new_dst_conn=in_conn) port_created = (in_conn, out_conn) else: @@ -977,20 +991,6 @@ def fuse(self, min_offsets = dict() # do one pass to compress all transient arrays - def change_data(transient_array, shape, strides, total_size, offset, lifetime, storage): - if shape is not None: - transient_array.shape = shape - if strides is not None: - transient_array.strides = strides - if total_size is not None: - transient_array.total_size = total_size - if offset is not None: - transient_array.offset = offset - if lifetime is not None: - transient_array.lifetime = lifetime - if storage is not None: - transient_array.storage = storage - data_intermediate = set([node.data for node in intermediate_nodes]) for data_name in data_intermediate: if subgraph_contains_data[data_name] and isinstance(sdfg.data(data_name), dace.data.Array): @@ -1006,8 +1006,7 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s in_edge = next(in_edges_iter) target_subset_curr = dcpy(in_edge.data.subset) target_subset_curr.pop(invariant_dimensions[data_name]) - target_subset = subsets.union(target_subset, \ - target_subset_curr) + target_subset = subsets.union(target_subset, target_subset_curr) except StopIteration: break @@ -1042,7 +1041,6 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s new_data_totalsize = data._prod(new_data_shape) new_data_offset = [0] * len(new_data_shape) - # compress original shape change_data(sdfg.data(data_name), shape=new_data_shape, strides=new_data_strides, @@ -1093,7 +1091,7 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s if isinstance(iedge.src, nodes.NestedSDFG): nsdfg = iedge.src.sdfg nested_data_name = iedge.src_conn - self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data, nested_data_name, iedge.data) + adjust_arrays_nsdfg(sdfg, nsdfg, node.data, nested_data_name, iedge.data) for cedge in out_edges: for edge in graph.memlet_tree(cedge): @@ -1105,7 +1103,7 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s if isinstance(edge.dst, nodes.NestedSDFG): nsdfg = edge.dst.sdfg nested_data_name = edge.dst_conn - self.adjust_arrays_nsdfg(sdfg, nsdfg, node.data, nested_data_name, edge.data) + adjust_arrays_nsdfg(sdfg, nsdfg, node.data, nested_data_name, edge.data) # if in_edges has several entries: # put other_subset into out_edges for correctness @@ -1255,7 +1253,8 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s warnings.warn(f'{dname}[{in_subset}] may intersect with {dname}[{oe.data.src_subset}]') elif intersect: raise ValueError(f'{dname}[{in_subset}] intersects with {dname}[{oe.data.src_subset}]') - # If the outgoing subset is not covered by the transient data, connect to the outer input node. + # If the outgoing subset is not covered by the transient data, connect to the outer input + # node. if not inode: inode = graph.add_access(dname) graph.add_memlet_path(inode, global_map_entry, oe.dst, memlet=oe.data, dst_conn=oe.dst_conn) From 04a229cefdd3e3acd7dbc0d045fcd6a1d3ff8124 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 14:19:10 +0200 Subject: [PATCH 13/16] Breakout two pieces of `fuse()` function into their own helpers. --- .../subgraph/subgraph_fusion.py | 440 +++++++++--------- 1 file changed, 229 insertions(+), 211 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 5c97c6845f..b02ffb7d10 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -4,16 +4,17 @@ from collections import defaultdict from copy import deepcopy as dcpy from itertools import chain -from typing import List, Tuple, Set, Iterable +from typing import List, Tuple, Set, Iterable, Dict import networkx as nx import dace -from dace import dtypes, symbolic, subsets, data +from dace import dtypes, symbolic, subsets, data, SDFGState from dace.memlet import Memlet from dace.properties import EnumProperty, ListProperty, make_properties, Property from dace.sdfg import nodes, SDFG from dace.sdfg.graph import SubgraphView +from dace.sdfg.nodes import MapExit, MapEntry, AccessNode from dace.sdfg.propagation import _propagate_node from dace.sdfg.utils import consolidate_edges_scope from dace.transformation import transformation @@ -432,6 +433,146 @@ def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, return subgraph_contains_data +def fuse_maps_into_global_map(graph: SDFGState, map_entries: List[MapEntry], map_exits: List[MapExit], + global_map_entry: MapEntry, global_map_exit: MapExit, + in_nodes: Set[AccessNode], intermediate_nodes: Set[AccessNode], + out_nodes: Set[AccessNode], + transients_created: Dict[AccessNode, AccessNode], + invariant_dimensions: Dict[str, Set[int]]): + inconnectors_dict = {} + # Dict for saving incoming nodes and their assigned connectors + # Format: {access_node: (edge, in_conn, out_conn)} + + for map_entry, map_exit in zip(map_entries, map_exits): + # handle inputs + # TODO: dynamic map range -- this is fairly unrealistic in such a setting + for edge in graph.in_edges(map_entry): + src = edge.src + out_edges = [ + e for e in graph.out_edges(map_entry) if (e.src_conn and e.src_conn[3:] == edge.dst_conn[2:]) + ] + + if src in in_nodes: + if src in inconnectors_dict: + # for access nodes only + out_conn = inconnectors_dict[src][2] + + else: + next_conn = global_map_entry.next_connector() + in_conn = 'IN_' + next_conn + out_conn = 'OUT_' + next_conn + global_map_entry.add_in_connector(in_conn) + global_map_entry.add_out_connector(out_conn) + + if isinstance(src, nodes.AccessNode): + inconnectors_dict[src] = (edge, in_conn, out_conn) + + # reroute in edge via global_map_entry + copy_edge(graph, edge, new_dst=global_map_entry, new_dst_conn=in_conn) + + # map out edges to new map + for out_edge in out_edges: + copy_edge(graph, out_edge, new_src=global_map_entry, new_src_conn=out_conn) + + else: + # connect directly + for out_edge in out_edges: + mm = dcpy(out_edge.data) + copy_edge(graph, out_edge, new_src=src, new_src_conn=None, new_data=mm) + + for edge in graph.out_edges(map_entry): + # special case: for nodes that have no data connections + if not edge.src_conn: + copy_edge(graph, edge, new_src=global_map_entry) + + ###################################### + + for edge in graph.in_edges(map_exit): + if not edge.dst_conn: + # no destination connector, path ends here. + copy_edge(graph, edge, new_dst=global_map_exit) + continue + # find corresponding out_edges for current edge + out_edges = [oedge for oedge in graph.out_edges(map_exit) if oedge.src_conn[3:] == edge.dst_conn[2:]] + + # Tuple to store in/out connector port that might be created + port_created = None + + for out_edge in out_edges: + dst = out_edge.dst + + if dst in intermediate_nodes & out_nodes: + + # create connection through global map from + # dst to dst_transient that was created + dst_transient = transients_created[dst] + next_conn = global_map_exit.next_connector() + in_conn = 'IN_' + next_conn + out_conn = 'OUT_' + next_conn + global_map_exit.add_in_connector(in_conn) + global_map_exit.add_out_connector(out_conn) + + # for each transient created, create a union + # of outgoing memlets' subsets. this is + # a cheap fix to override assignments in invariant + # dimensions + union = None + for oe in graph.out_edges(transients_created[dst]): + union = subsets.union(union, oe.data.subset) + if isinstance(union, subsets.Indices): + union = subsets.Range.from_indices(union) + inner_memlet = dcpy(edge.data) + for i, s in enumerate(edge.data.subset): + if i in invariant_dimensions[dst.label]: + inner_memlet.subset[i] = union[i] + + inner_memlet.other_subset = dcpy(inner_memlet.subset) + + e_inner = graph.add_edge(dst, None, global_map_exit, in_conn, inner_memlet) + + outer_memlet = dcpy(out_edge.data) + e_outer = graph.add_edge(global_map_exit, out_conn, dst_transient, None, outer_memlet) + + # remove edge from dst to dst_transient that was created + # in intermediate preparation. + for e in graph.out_edges(dst): + if e.dst == dst_transient: + graph.remove_edge(e) + break + + # handle separately: intermediate_nodes and pure out nodes + # case 1: intermediate_nodes: can just redirect edge + if dst in intermediate_nodes: + copy_edge(graph, + out_edge, + new_src=edge.src, + new_src_conn=edge.src_conn, + new_data=dcpy(edge.data)) + + # case 2: pure out node: connect to outer array node + if dst in (out_nodes - intermediate_nodes): + if edge.dst != global_map_exit: + next_conn = global_map_exit.next_connector() + + in_conn = 'IN_' + next_conn + out_conn = 'OUT_' + next_conn + global_map_exit.add_in_connector(in_conn) + global_map_exit.add_out_connector(out_conn) + copy_edge(graph, edge, new_dst=global_map_exit, new_dst_conn=in_conn) + port_created = (in_conn, out_conn) + + else: + in_conn, out_conn = port_created + + # map + graph.add_edge(global_map_exit, out_conn, dst, out_edge.dst_conn, dcpy(out_edge.data)) + + # maps are now ready to be discarded + # all connected edges will be finally removed as well + graph.remove_node(map_entry) + graph.remove_node(map_exit) + + def change_data(transient_array: dace.data.Array, shape, strides, total_size, offset, lifetime, storage): """Compress original shape""" if shape is not None: @@ -448,6 +589,85 @@ def change_data(transient_array: dace.data.Array, shape, strides, total_size, of transient_array.storage = storage +def compress_transient_arrays(sdfg: SDFG, graph: SDFGState, transient_allocation: bool, + subgraph_contains_data: Dict[str, bool], intermediate_nodes: Set[AccessNode], + invariant_dimensions: Dict[str, Set[int]]): + """Do one pass to compress all transient arrays.""" + + # create a mapping from data arrays to offsets + # for later memlet adjustments later + min_offsets = dict() + + data_intermediate = set([node.data for node in intermediate_nodes]) + for data_name in data_intermediate: + if subgraph_contains_data[data_name] and isinstance(sdfg.data(data_name), dace.data.Array): + all_nodes = [n for n in intermediate_nodes if n.data == data_name] + in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes))) + + in_edges_iter = iter(in_edges) + in_edge = next(in_edges_iter) + target_subset = dcpy(in_edge.data.subset) + target_subset.pop(invariant_dimensions[data_name]) + while True: + try: # executed if there are multiple in_edges + in_edge = next(in_edges_iter) + target_subset_curr = dcpy(in_edge.data.subset) + target_subset_curr.pop(invariant_dimensions[data_name]) + target_subset = subsets.union(target_subset, target_subset_curr) + except StopIteration: + break + + min_offsets_cropped = target_subset.min_element_approx() + # calculate the new transient array size. + target_subset.offset(min_offsets_cropped, True) + + # re-add invariant dimensions with offset 0 and save to min_offsets + min_offset = [] + index = 0 + for i in range(len(sdfg.data(data_name).shape)): + if i in invariant_dimensions[data_name]: + min_offset.append(0) + else: + min_offset.append(min_offsets_cropped[index]) + index += 1 + + min_offsets[data_name] = min_offset + + # determine the shape of the new array. + new_data_shape = [] + index = 0 + for i, sz in enumerate(sdfg.data(data_name).shape): + if i in invariant_dimensions[data_name]: + new_data_shape.append(sz) + else: + new_data_shape.append(target_subset.size()[index]) + index += 1 + + new_data_strides = [data._prod(new_data_shape[i + 1:]) for i in range(len(new_data_shape))] + + new_data_totalsize = data._prod(new_data_shape) + new_data_offset = [0] * len(new_data_shape) + + change_data(sdfg.data(data_name), + shape=new_data_shape, + strides=new_data_strides, + total_size=new_data_totalsize, + offset=new_data_offset, + lifetime=dtypes.AllocationLifetime.Scope, + storage=transient_allocation) + + else: + # don't modify data container - array is needed outside + # of subgraph. + + # hack: set lifetime to State if allocation has only been + # scope so far to avoid allocation issues + if sdfg.data(data_name).lifetime == dtypes.AllocationLifetime.Scope: + sdfg.data(data_name).lifetime = dtypes.AllocationLifetime.State + + return min_offsets + + @make_properties class SubgraphFusion(transformation.SubgraphTransformation): """ @@ -816,8 +1036,7 @@ def fuse(self, map_exits = [graph.exit_node(map_entry) for map_entry in map_entries] # See function documentation for an explanation of these variables - node_config = get_adjacent_nodes(graph, map_entries) - (in_nodes, intermediate_nodes, out_nodes) = node_config + in_nodes, intermediate_nodes, out_nodes = get_adjacent_nodes(graph, map_entries) if self.debug: print("SubgraphFusion::In_nodes", in_nodes) @@ -845,218 +1064,17 @@ def fuse(self, map_entries, map_exits, do_not_override) - (subgraph_contains_data, transients_created, invariant_dimensions) = node_info + subgraph_contains_data, transients_created, invariant_dimensions = node_info if self.debug: print("SubgraphFusion:: {Intermediate_node: subgraph_contains_data} dict") print(subgraph_contains_data) - inconnectors_dict = {} - # Dict for saving incoming nodes and their assigned connectors - # Format: {access_node: (edge, in_conn, out_conn)} - - for map_entry, map_exit in zip(map_entries, map_exits): - # handle inputs - # TODO: dynamic map range -- this is fairly unrealistic in such a setting - for edge in graph.in_edges(map_entry): - src = edge.src - out_edges = [ - e for e in graph.out_edges(map_entry) if (e.src_conn and e.src_conn[3:] == edge.dst_conn[2:]) - ] - - if src in in_nodes: - in_conn = None - out_conn = None - if src in inconnectors_dict: - # for access nodes only - in_conn = inconnectors_dict[src][1] - out_conn = inconnectors_dict[src][2] - - else: - next_conn = global_map_entry.next_connector() - in_conn = 'IN_' + next_conn - out_conn = 'OUT_' + next_conn - global_map_entry.add_in_connector(in_conn) - global_map_entry.add_out_connector(out_conn) - - if isinstance(src, nodes.AccessNode): - inconnectors_dict[src] = (edge, in_conn, out_conn) - - # reroute in edge via global_map_entry - copy_edge(graph, edge, new_dst=global_map_entry, new_dst_conn=in_conn) - - # map out edges to new map - for out_edge in out_edges: - copy_edge(graph, out_edge, new_src=global_map_entry, new_src_conn=out_conn) - - else: - # connect directly - for out_edge in out_edges: - mm = dcpy(out_edge.data) - copy_edge(graph, out_edge, new_src=src, new_src_conn=None, new_data=mm) - - for edge in graph.out_edges(map_entry): - # special case: for nodes that have no data connections - if not edge.src_conn: - copy_edge(graph, edge, new_src=global_map_entry) - - ###################################### - - for edge in graph.in_edges(map_exit): - if not edge.dst_conn: - # no destination connector, path ends here. - copy_edge(graph, edge, new_dst=global_map_exit) - continue - # find corresponding out_edges for current edge - out_edges = [oedge for oedge in graph.out_edges(map_exit) if oedge.src_conn[3:] == edge.dst_conn[2:]] - - # Tuple to store in/out connector port that might be created - port_created = None - - for out_edge in out_edges: - dst = out_edge.dst - - if dst in intermediate_nodes & out_nodes: - - # create connection through global map from - # dst to dst_transient that was created - dst_transient = transients_created[dst] - next_conn = global_map_exit.next_connector() - in_conn = 'IN_' + next_conn - out_conn = 'OUT_' + next_conn - global_map_exit.add_in_connector(in_conn) - global_map_exit.add_out_connector(out_conn) - - # for each transient created, create a union - # of outgoing memlets' subsets. this is - # a cheap fix to override assignments in invariant - # dimensions - union = None - for oe in graph.out_edges(transients_created[dst]): - union = subsets.union(union, oe.data.subset) - if isinstance(union, subsets.Indices): - union = subsets.Range.from_indices(union) - inner_memlet = dcpy(edge.data) - for i, s in enumerate(edge.data.subset): - if i in invariant_dimensions[dst.label]: - inner_memlet.subset[i] = union[i] - - inner_memlet.other_subset = dcpy(inner_memlet.subset) - - e_inner = graph.add_edge(dst, None, global_map_exit, in_conn, inner_memlet) - - outer_memlet = dcpy(out_edge.data) - e_outer = graph.add_edge(global_map_exit, out_conn, dst_transient, None, outer_memlet) - - # remove edge from dst to dst_transient that was created - # in intermediate preparation. - for e in graph.out_edges(dst): - if e.dst == dst_transient: - graph.remove_edge(e) - break - - # handle separately: intermediate_nodes and pure out nodes - # case 1: intermediate_nodes: can just redirect edge - if dst in intermediate_nodes: - copy_edge(graph, - out_edge, - new_src=edge.src, - new_src_conn=edge.src_conn, - new_data=dcpy(edge.data)) - - # case 2: pure out node: connect to outer array node - if dst in (out_nodes - intermediate_nodes): - if edge.dst != global_map_exit: - next_conn = global_map_exit.next_connector() - - in_conn = 'IN_' + next_conn - out_conn = 'OUT_' + next_conn - global_map_exit.add_in_connector(in_conn) - global_map_exit.add_out_connector(out_conn) - copy_edge(graph, edge, new_dst=global_map_exit, new_dst_conn=in_conn) - port_created = (in_conn, out_conn) - - else: - in_conn, out_conn = port_created - - # map - graph.add_edge(global_map_exit, out_conn, dst, out_edge.dst_conn, dcpy(out_edge.data)) - - # maps are now ready to be discarded - # all connected edges will be finally removed as well - graph.remove_node(map_entry) - graph.remove_node(map_exit) - - # create a mapping from data arrays to offsets - # for later memlet adjustments later - min_offsets = dict() - - # do one pass to compress all transient arrays - data_intermediate = set([node.data for node in intermediate_nodes]) - for data_name in data_intermediate: - if subgraph_contains_data[data_name] and isinstance(sdfg.data(data_name), dace.data.Array): - all_nodes = [n for n in intermediate_nodes if n.data == data_name] - in_edges = list(chain(*(graph.in_edges(n) for n in all_nodes))) - - in_edges_iter = iter(in_edges) - in_edge = next(in_edges_iter) - target_subset = dcpy(in_edge.data.subset) - target_subset.pop(invariant_dimensions[data_name]) - while True: - try: # executed if there are multiple in_edges - in_edge = next(in_edges_iter) - target_subset_curr = dcpy(in_edge.data.subset) - target_subset_curr.pop(invariant_dimensions[data_name]) - target_subset = subsets.union(target_subset, target_subset_curr) - except StopIteration: - break - - min_offsets_cropped = target_subset.min_element_approx() - # calculate the new transient array size. - target_subset.offset(min_offsets_cropped, True) - - # re-add invariant dimensions with offset 0 and save to min_offsets - min_offset = [] - index = 0 - for i in range(len(sdfg.data(data_name).shape)): - if i in invariant_dimensions[data_name]: - min_offset.append(0) - else: - min_offset.append(min_offsets_cropped[index]) - index += 1 - - min_offsets[data_name] = min_offset - - # determine the shape of the new array. - new_data_shape = [] - index = 0 - for i, sz in enumerate(sdfg.data(data_name).shape): - if i in invariant_dimensions[data_name]: - new_data_shape.append(sz) - else: - new_data_shape.append(target_subset.size()[index]) - index += 1 - - new_data_strides = [data._prod(new_data_shape[i + 1:]) for i in range(len(new_data_shape))] - - new_data_totalsize = data._prod(new_data_shape) - new_data_offset = [0] * len(new_data_shape) - - change_data(sdfg.data(data_name), - shape=new_data_shape, - strides=new_data_strides, - total_size=new_data_totalsize, - offset=new_data_offset, - lifetime=dtypes.AllocationLifetime.Scope, - storage=self.transient_allocation) - - else: - # don't modify data container - array is needed outside - # of subgraph. + fuse_maps_into_global_map(graph, map_entries, map_exits, global_map_entry, global_map_exit, + in_nodes, intermediate_nodes, out_nodes, transients_created, invariant_dimensions) + sdfg.validate() - # hack: set lifetime to State if allocation has only been - # scope so far to avoid allocation issues - if sdfg.data(data_name).lifetime == dtypes.AllocationLifetime.Scope: - sdfg.data(data_name).lifetime = dtypes.AllocationLifetime.State + min_offsets = compress_transient_arrays(sdfg, graph, self.transient_allocation, subgraph_contains_data, + intermediate_nodes, invariant_dimensions) # do one pass to adjust strides and the memlets of in-between transients for node in intermediate_nodes: From f355a6b7e271c7fb93e2b762e7e950e3a9060f19 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 14:36:09 +0200 Subject: [PATCH 14/16] Finally, fix the bug in subgraph fusion that was causing index error. --- dace/transformation/subgraph/subgraph_fusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index b02ffb7d10..a03feeeb07 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -606,7 +606,7 @@ def compress_transient_arrays(sdfg: SDFG, graph: SDFGState, transient_allocation in_edges_iter = iter(in_edges) in_edge = next(in_edges_iter) - target_subset = dcpy(in_edge.data.subset) + target_subset = dcpy(in_edge.data.dst_subset) target_subset.pop(invariant_dimensions[data_name]) while True: try: # executed if there are multiple in_edges @@ -957,7 +957,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: for ie in graph.in_edges(e.src): # get corresponding inner memlet and join its subset to our access set if ie.dst_conn[2:] == e.src_conn[3:]: - current_subset = dcpy(ie.data.subset) + current_subset = dcpy(ie.data.dst_subset) current_subset.pop(invariant_dimensions[node_data]) access_set = subsets.union(access_set, current_subset) From 490b4156cdacbe1f2f6d85980f182043bdae67ee Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 15:14:52 +0200 Subject: [PATCH 15/16] Forgot to update the references to subgraph fusion, fixing now + Removing unnecessary imports. --- .../transformation/subgraph/stencil_tiling.py | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index c9c3e9afd4..f6c60eea8f 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -2,34 +2,27 @@ """ This module contains classes and functions that implement the orthogonal stencil tiling transformation. """ -import math +import itertools +import warnings +from collections import defaultdict +from copy import deepcopy as dcpy import dace -from dace import dtypes, registry, symbolic +import dace.subsets as subsets +import dace.symbolic as symbolic +from dace import dtypes from dace.properties import make_properties, Property, ShapeProperty from dace.sdfg import nodes -from dace.transformation import transformation from dace.sdfg.propagation import _propagate_node - -from dace.transformation.dataflow.map_for_loop import MapToForLoop -from dace.transformation.dataflow.map_expansion import MapExpansion +from dace.transformation import transformation from dace.transformation.dataflow.map_collapse import MapCollapse +from dace.transformation.dataflow.map_expansion import MapExpansion +from dace.transformation.dataflow.map_for_loop import MapToForLoop from dace.transformation.dataflow.strip_mining import StripMining -from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.interstate.loop_detection import DetectLoop -from dace.transformation.subgraph import SubgraphFusion - -from copy import deepcopy as dcpy - -import dace.subsets as subsets -import dace.symbolic as symbolic - -import itertools -import warnings - -from collections import defaultdict - +from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.subgraph import helpers +from dace.transformation.subgraph import subgraph_fusion @make_properties @@ -51,7 +44,7 @@ class StencilTiling(transformation.SubgraphTransformation): prefix = Property(dtype=str, default="stencil", desc="Prefix for new inner tiled range symbols") - strides = ShapeProperty(dtype=tuple, default=(1, ), desc="Tile stride") + strides = ShapeProperty(dtype=tuple, default=(1,), desc="Tile stride") schedule = Property(dtype=dace.dtypes.ScheduleType, default=dace.dtypes.ScheduleType.Default, @@ -200,13 +193,13 @@ def can_be_applied(sdfg, subgraph) -> bool: # get intermediate_nodes, out_nodes from SubgraphFusion Transformation try: - node_config = get_adjacent_nodes(sdfg, graph, map_entries) + node_config = subgraph_fusion.get_adjacent_nodes(graph, map_entries) (_, intermediate_nodes, out_nodes) = node_config except NotImplementedError: return False # 1.4: check topological feasibility - if not check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes): + if not subgraph_fusion.check_topo_feasibility(graph, map_entries, intermediate_nodes, out_nodes): return False # 1.5 nodes that are both intermediate and out nodes # are not supported in StencilTiling @@ -215,8 +208,8 @@ def can_be_applied(sdfg, subgraph) -> bool: # 1.6 check that we only deal with compressible transients - subgraph_contains_data = determine_compressible_nodes(sdfg, graph, intermediate_nodes, - map_entries, map_exits) + subgraph_contains_data = subgraph_fusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes, + map_entries, map_exits) if any([s == False for s in subgraph_contains_data.values()]): return False @@ -264,8 +257,8 @@ def can_be_applied(sdfg, subgraph) -> bool: for i, (p_subset, c_subset) in enumerate(zip(parent_coverage, children_coverage)): # transform into subset - p_subset = subsets.Range((p_subset, )) - c_subset = subsets.Range((c_subset, )) + p_subset = subsets.Range((p_subset,)) + c_subset = subsets.Range((c_subset,)) # get associated parameter in memlet params1 = symbolic.symlist(memlets[map_entry][1][data_name][i]).keys() @@ -292,7 +285,7 @@ def can_be_applied(sdfg, subgraph) -> bool: except KeyError: return False - #parameter mapping must be the same + # parameter mapping must be the same if param_parent_coverage != param_children_coverage: return False @@ -394,7 +387,7 @@ def apply(self, sdfg): for data_name, ranges in local_ranges.items(): for param, r in zip(variable_mapping[data_name], ranges): # create new range from this subset and assign - rng = subsets.Range((r, )) + rng = subsets.Range((r,)) if param: inferred_ranges[map_entry][param] = subsets.union(inferred_ranges[map_entry][param], rng) @@ -457,9 +450,9 @@ def apply(self, sdfg): reference_range_current = self.reference_range[param] min_diff = symbolic.SymExpr(reference_range_current.min_element()[0] \ - - target_range_current.min_element()[0]) + - target_range_current.min_element()[0]) max_diff = symbolic.SymExpr(target_range_current.max_element()[0] \ - - reference_range_current.max_element()[0]) + - reference_range_current.max_element()[0]) try: min_diff = symbolic.evaluate(min_diff, {}) From 8ce705596c29f7e4d7806ed125754df82ab9883f Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 16 Oct 2024 15:51:17 +0200 Subject: [PATCH 16/16] Validating halfway in the fusion wasn't a good idea. --- dace/transformation/subgraph/subgraph_fusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index a03feeeb07..c686e7c66e 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -1071,7 +1071,6 @@ def fuse(self, fuse_maps_into_global_map(graph, map_entries, map_exits, global_map_entry, global_map_exit, in_nodes, intermediate_nodes, out_nodes, transients_created, invariant_dimensions) - sdfg.validate() min_offsets = compress_transient_arrays(sdfg, graph, self.transient_allocation, subgraph_contains_data, intermediate_nodes, invariant_dimensions)