From 036e4073f8662ea407f1939bd415ac7934a26c24 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 11:55:58 +0200 Subject: [PATCH 001/108] Add data dependency analyses --- .../analysis/writeset_underapproximation.py | 311 +++++++++--------- dace/transformation/pass_pipeline.py | 3 +- .../passes/analysis/__init__.py | 1 + .../passes/{ => analysis}/analysis.py | 125 +++++-- .../analysis/control_flow_region_analysis.py | 229 +++++++++++++ .../passes/analysis/loop_analysis.py | 213 ++++++++++++ 6 files changed, 691 insertions(+), 191 deletions(-) create mode 100644 dace/transformation/passes/analysis/__init__.py rename dace/transformation/passes/{ => analysis}/analysis.py (83%) create mode 100644 dace/transformation/passes/analysis/control_flow_region_analysis.py create mode 100644 dace/transformation/passes/analysis/loop_analysis.py diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index bfd5f4cb00..afc34add30 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -8,35 +8,22 @@ import copy import itertools import warnings -from typing import Any, Dict, List, Set, Tuple, Type, Union +from typing import Any, Dict, List, Set, Tuple, Type, TypedDict, Union import sympy import dace +from dace.sdfg.state import LoopRegion +from dace.transformation import transformation from dace.symbolic import issymbolic, pystr_to_symbolic, simplify from dace.transformation.pass_pipeline import Modifies, Pass from dace import registry, subsets, symbolic, dtypes, data, SDFG, Memlet from dace.sdfg.nodes import NestedSDFG, AccessNode from dace.sdfg import nodes, SDFGState, graph as gr -from dace.sdfg.analysis import cfg +from dace.sdfg.analysis import cfg as cfg_analysis from dace.transformation import pass_pipeline as ppl from dace.sdfg import graph from dace.sdfg import scope - -# dictionary mapping each edge to a copy of the memlet of that edge with its write set -# underapproximated -approximation_dict: Dict[graph.Edge, Memlet] = {} -# dictionary that maps loop headers to "border memlets" that are written to in the -# corresponding loop -loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] = {} -# dictionary containing information about the for loops in the SDFG -loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, - List[SDFGState], str, subsets.Range]] = {} -# dictionary mapping each nested SDFG to the iteration variables surrounding it -iteration_variables: Dict[SDFG, Set[str]] = {} -# dictionary mapping each state to the iteration variables surrounding it -# (including the ones from surrounding SDFGs) -ranges_per_state: Dict[SDFGState, - Dict[str, subsets.Range]] = defaultdict(lambda: {}) +from dace.transformation.passes.analysis import loop_analysis @registry.make_registry @@ -417,7 +404,7 @@ def _find_unconditionally_executed_states(sdfg: SDFG) -> Set[SDFGState]: sdfg.add_edge(sink_node, dummy_sink, dace.sdfg.InterstateEdge()) # get all the nodes that are executed unconditionally in the state-machine a.k.a nodes # that dominate the sink states - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) states = dominators[dummy_sink] # remove dummy state sdfg.remove_node(dummy_sink) @@ -689,21 +676,34 @@ def _merge_subsets(subset_a: subsets.Subset, subset_b: subsets.Subset) -> subset return subset_b +class UnderapproximateWritesDictT(TypedDict): + approximation: Dict[graph.Edge, Memlet] + loop_approximation: Dict[SDFGState, Dict[str, Memlet]] + loops: Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] + + +@transformation.experimental_cfg_block_compatible class UnderapproximateWrites(ppl.Pass): + # Dictionary mapping each edge to a copy of the memlet of that edge with its write set underapproximated. + approximation_dict: Dict[graph.Edge, Memlet] = {} + # Dictionary that maps loop headers to "border memlets" that are written to in the corresponding loop. + loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] = {} + # Dictionary containing information about the for loops in the SDFG. + loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] = {} + # Dictionary mapping each nested SDFG to the iteration variables surrounding it. + iteration_variables: Dict[SDFG, Set[str]] = {} + # Mapping of state to the iteration variables surrounding them, including the ones from surrounding SDFGs. + ranges_per_state: Dict[SDFGState, Dict[str, subsets.Range]] = defaultdict(lambda: {}) + def modifies(self) -> Modifies: - return ppl.Modifies.Everything + return ppl.Modifies.States def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - - def apply_pass( - self, sdfg: dace.SDFG, pipeline_results: Dict[str, Any] - ) -> Dict[str, Union[ - Dict[graph.Edge, Memlet], - Dict[SDFGState, Dict[str, Memlet]], - Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]]]]: + # If anything was modified, reapply. + return modified & ppl.Modifies.Everything + + def apply_pass(self, top_sdfg: dace.SDFG, _) -> Dict[int, UnderapproximateWritesDictT]: """ Applies the pass to the given SDFG. @@ -725,42 +725,49 @@ def apply_pass( :notes: The only modification this pass performs on the SDFG is splitting interstate edges. """ - # clear the global dictionaries - approximation_dict.clear() - loop_write_dict.clear() - loop_dict.clear() - iteration_variables.clear() - ranges_per_state.clear() - - # fill the approximation dictionary with the original edges as keys and the edges with the - # approximated memlets as values - for (edge, parent) in sdfg.all_edges_recursive(): - if isinstance(parent, SDFGState): - approximation_dict[edge] = copy.deepcopy(edge.data) - if not isinstance(approximation_dict[edge].subset, - subsets.SubsetUnion) and approximation_dict[edge].subset: - approximation_dict[edge].subset = subsets.SubsetUnion( - [approximation_dict[edge].subset]) - if not isinstance(approximation_dict[edge].dst_subset, - subsets.SubsetUnion) and approximation_dict[edge].dst_subset: - approximation_dict[edge].dst_subset = subsets.SubsetUnion( - [approximation_dict[edge].dst_subset]) - if not isinstance(approximation_dict[edge].src_subset, - subsets.SubsetUnion) and approximation_dict[edge].src_subset: - approximation_dict[edge].src_subset = subsets.SubsetUnion( - [approximation_dict[edge].src_subset]) - - self._underapproximate_writes_sdfg(sdfg) - - # Replace None with empty SubsetUnion in each Memlet - for entry in approximation_dict.values(): - if entry.subset is None: - entry.subset = subsets.SubsetUnion([]) - return { - "approximation": approximation_dict, - "loop_approximation": loop_write_dict, - "loops": loop_dict - } + result = defaultdict(lambda: {'approximation': dict(), 'loop_approximation': dict(), 'loops': dict()}) + + for sdfg in top_sdfg.all_sdfgs_recursive(): + # Clear the global dictionaries. + self.approximation_dict.clear() + self.loop_write_dict.clear() + self.loop_dict.clear() + self.iteration_variables.clear() + self.ranges_per_state.clear() + + # fill the approximation dictionary with the original edges as keys and the edges with the + # approximated memlets as values + for (edge, parent) in sdfg.all_edges_recursive(): + if isinstance(parent, SDFGState): + self.approximation_dict[edge] = copy.deepcopy(edge.data) + if not isinstance(self.approximation_dict[edge].subset, + subsets.SubsetUnion) and self.approximation_dict[edge].subset: + self.approximation_dict[edge].subset = subsets.SubsetUnion([ + self.approximation_dict[edge].subset + ]) + if not isinstance(self.approximation_dict[edge].dst_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].dst_subset: + self.approximation_dict[edge].dst_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].dst_subset + ]) + if not isinstance(self.approximation_dict[edge].src_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].src_subset: + self.approximation_dict[edge].src_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].src_subset + ]) + + self._underapproximate_writes_sdfg(sdfg) + + # Replace None with empty SubsetUnion in each Memlet + for entry in self.approximation_dict.values(): + if entry.subset is None: + entry.subset = subsets.SubsetUnion([]) + + result[sdfg.cfg_id]['approximation'] = self.approximation_dict + result[sdfg.cfg_id]['loop_approximation'] = self.loop_write_dict + result[sdfg.cfg_id]['loops'] = self.loop_dict + + return result def _underapproximate_writes_sdfg(self, sdfg: SDFG): """ @@ -770,10 +777,18 @@ def _underapproximate_writes_sdfg(self, sdfg: SDFG): split_interstate_edges(sdfg) loops = self._find_for_loops(sdfg) - loop_dict.update(loops) + self.loop_dict.update(loops) + + for region in sdfg.all_control_flow_regions(): + if isinstance(region, LoopRegion): + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + for state in region.all_states(): + self.ranges_per_state[state][region.loop_variable] = subsets.Range([(start, stop, stride)]) - for state in sdfg.nodes(): - self._underapproximate_writes_state(sdfg, state) + for state in region.all_states(): + self._underapproximate_writes_state(sdfg, state) self._underapproximate_writes_loops(loops, sdfg) @@ -885,13 +900,12 @@ def _find_for_loops(self, sources=[begin], condition=lambda _, child: child != guard) - if itvar not in ranges_per_state[begin]: + if itvar not in self.ranges_per_state[begin]: for loop_state in loop_states: - ranges_per_state[loop_state][itervar] = subsets.Range([ - rng]) + self.ranges_per_state[loop_state][itervar] = subsets.Range([rng]) loop_state_list.append(loop_state) - ranges_per_state[guard][itervar] = subsets.Range([rng]) + self.ranges_per_state[guard][itervar] = subsets.Range([rng]) identified_loops[guard] = (begin, last_loop_state, loop_state_list, itvar, subsets.Range([rng])) @@ -934,8 +948,11 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # approximation_dict # First, propagate nested SDFGs in a bottom-up fashion + dnodes: Set[nodes.AccessNode] = set() for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): + if isinstance(node, AccessNode): + dnodes.add(node) + elif isinstance(node, nodes.NestedSDFG): self._find_live_iteration_variables(node, sdfg, state) # Propagate memlets inside the nested SDFG. @@ -947,6 +964,15 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # Process scopes from the leaves upwards self._underapproximate_writes_scope(sdfg, state, state.scope_leaves()) + # Make sure any scalar writes are also added if they have not been processed yet. + for dn in dnodes: + desc = sdfg.data(dn.data) + if isinstance(desc, data.Scalar) or (isinstance(desc, data.Array) and desc.total_size == 1): + for iedge in state.in_edges(dn): + if not iedge in self.approximation_dict: + self.approximation_dict[iedge] = copy.deepcopy(iedge.data) + self.approximation_dict[iedge]._edge = iedge + def _find_live_iteration_variables(self, nsdfg: nodes.NestedSDFG, sdfg: SDFG, @@ -963,15 +989,14 @@ def symbol_map(mapping, symbol): return None map_iteration_variables = _collect_iteration_variables(state, nsdfg) - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - state_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + state_iteration_variables = self.ranges_per_state[state].keys() iteration_variables_local = (map_iteration_variables | sdfg_iteration_variables | state_iteration_variables) mapped_iteration_variables = set( map(lambda x: symbol_map(nsdfg.symbol_mapping, x), iteration_variables_local)) if mapped_iteration_variables: - iteration_variables[nsdfg.sdfg] = mapped_iteration_variables + self.iteration_variables[nsdfg.sdfg] = mapped_iteration_variables def _underapproximate_writes_nested_sdfg( self, @@ -1025,12 +1050,11 @@ def _init_border_memlet(template_memlet: Memlet, # Collect all memlets belonging to this access node memlets = [] for edge in edges: - inside_memlet = approximation_dict[edge] + inside_memlet = self.approximation_dict[edge] memlets.append(inside_memlet) # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - inside_memlet, node.label) + border_memlet = _init_border_memlet(inside_memlet, node.label) # Given all of this access nodes' memlets union all the subsets to one SubsetUnion if len(memlets) > 0: @@ -1042,18 +1066,16 @@ def _init_border_memlet(template_memlet: Memlet, border_memlet.subset, subset) # collect the memlets for each loop in the NSDFG - if state in loop_write_dict: - for node_label, loop_memlet in loop_write_dict[state].items(): + if state in self.loop_write_dict: + for node_label, loop_memlet in self.loop_write_dict[state].items(): if node_label not in border_memlets: continue border_memlet = border_memlets[node_label] # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - loop_memlet, node_label) + border_memlet = _init_border_memlet(loop_memlet, node_label) # compute the union of the ranges to merge the subsets. - border_memlet.subset = _merge_subsets( - border_memlet.subset, loop_memlet.subset) + border_memlet.subset = _merge_subsets(border_memlet.subset, loop_memlet.subset) # Make sure any potential NSDFG symbol mapping is correctly reversed # when propagating out. @@ -1068,17 +1090,16 @@ def _init_border_memlet(template_memlet: Memlet, # Propagate the inside 'border' memlets outside the SDFG by # offsetting, and unsqueezing if necessary. for edge in parent_state.out_edges(nsdfg_node): - out_memlet = approximation_dict[edge] + out_memlet = self.approximation_dict[edge] if edge.src_conn in border_memlets: internal_memlet = border_memlets[edge.src_conn] if internal_memlet is None: out_memlet.subset = None out_memlet.dst_subset = None - approximation_dict[edge] = out_memlet + self.approximation_dict[edge] = out_memlet continue - out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, - nsdfg_node) - approximation_dict[edge] = out_memlet + out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, nsdfg_node) + self.approximation_dict[edge] = out_memlet def _underapproximate_writes_loop(self, sdfg: SDFG, @@ -1099,9 +1120,7 @@ def _underapproximate_writes_loop(self, propagate_memlet_loop will be called recursively on the outermost loopheaders """ - def _init_border_memlet(template_memlet: Memlet, - node_label: str - ): + def _init_border_memlet(template_memlet: Memlet, node_label: str): ''' Creates a Memlet with the same data as the template_memlet, stores it in the border_memlets dictionary and returns it. @@ -1111,8 +1130,7 @@ def _init_border_memlet(template_memlet: Memlet, border_memlets[node_label] = border_memlet return border_memlet - def filter_subsets(itvar: str, itrange: subsets.Range, - memlet: Memlet) -> List[subsets.Subset]: + def filter_subsets(itvar: str, itrange: subsets.Range, memlet: Memlet) -> List[subsets.Subset]: # helper method that filters out subsets that do not depend on the iteration variable # if the iteration range is symbolic @@ -1134,7 +1152,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, if rng.num_elements() == 0: return # make sure there is no break out of the loop - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) if any(begin not in dominators[s] and not begin is s for s in loop_states): return border_memlets = defaultdict(None) @@ -1159,7 +1177,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, # collect all the subsets of the incoming memlets for the current access node for edge in edges: - inside_memlet = copy.copy(approximation_dict[edge]) + inside_memlet = copy.copy(self.approximation_dict[edge]) # filter out subsets that could become empty depending on assignments # of symbols filtered_subsets = filter_subsets( @@ -1177,35 +1195,27 @@ def filter_subsets(itvar: str, itrange: subsets.Range, self._underapproximate_writes_loop_subset(sdfg, memlets, border_memlet, sdfg.arrays[node.label], itvar, rng) - if state not in loop_write_dict: + if state not in self.loop_write_dict: continue # propagate the border memlets of nested loop - for node_label, other_border_memlet in loop_write_dict[state].items(): + for node_label, other_border_memlet in self.loop_write_dict[state].items(): # filter out subsets that could become empty depending on symbol assignments - filtered_subsets = filter_subsets( - itvar, rng, other_border_memlet) + filtered_subsets = filter_subsets(itvar, rng, other_border_memlet) if not filtered_subsets: continue - other_border_memlet.subset = subsets.SubsetUnion( - filtered_subsets) + other_border_memlet.subset = subsets.SubsetUnion(filtered_subsets) border_memlet = border_memlets.get(node_label) if border_memlet is None: - border_memlet = _init_border_memlet( - other_border_memlet, node_label) + border_memlet = _init_border_memlet(other_border_memlet, node_label) self._underapproximate_writes_loop_subset(sdfg, [other_border_memlet], border_memlet, sdfg.arrays[node_label], itvar, rng) - loop_write_dict[loop_header] = border_memlets + self.loop_write_dict[loop_header] = border_memlets - def _underapproximate_writes_loop_subset(self, - sdfg: dace.SDFG, - memlets: List[Memlet], - dst_memlet: Memlet, - arr: dace.data.Array, - itvar: str, - rng: subsets.Subset, + def _underapproximate_writes_loop_subset(self, sdfg: dace.SDFG, memlets: List[Memlet], dst_memlet: Memlet, + arr: dace.data.Array, itvar: str, rng: subsets.Subset, loop_nest_itvars: Union[Set[str], None] = None): """ Helper function that takes a list of (border) memlets, propagates them out of a @@ -1223,16 +1233,11 @@ def _underapproximate_writes_loop_subset(self, if len(memlets) > 0: params = [itvar] # get all the other iteration variables surrounding this memlet - surrounding_itvars = iteration_variables[sdfg] if sdfg in iteration_variables else set( - ) + surrounding_itvars = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() if loop_nest_itvars: surrounding_itvars |= loop_nest_itvars - subset = self._underapproximate_subsets(memlets, - arr, - params, - rng, - use_dst=True, + subset = self._underapproximate_subsets(memlets, arr, params, rng, use_dst=True, surrounding_itvars=surrounding_itvars).subset if subset is None or len(subset.subset_list) == 0: @@ -1240,9 +1245,7 @@ def _underapproximate_writes_loop_subset(self, # compute the union of the ranges to merge the subsets. dst_memlet.subset = _merge_subsets(dst_memlet.subset, subset) - def _underapproximate_writes_scope(self, - sdfg: SDFG, - state: SDFGState, + def _underapproximate_writes_scope(self, sdfg: SDFG, state: SDFGState, scopes: Union[scope.ScopeTree, List[scope.ScopeTree]]): """ Propagate memlets from the given scopes outwards. @@ -1253,8 +1256,7 @@ def _underapproximate_writes_scope(self, """ # for each map scope find the iteration variables of surrounding maps - surrounding_map_vars: Dict[scope.ScopeTree, - Set[str]] = _collect_itvars_scope(scopes) + surrounding_map_vars: Dict[scope.ScopeTree, Set[str]] = _collect_itvars_scope(scopes) if isinstance(scopes, scope.ScopeTree): scopes_to_process = [scopes] else: @@ -1272,8 +1274,7 @@ def _underapproximate_writes_scope(self, sdfg, state, surrounding_map_vars) - self._underapproximate_writes_node( - state, scope_node.exit, surrounding_iteration_variables) + self._underapproximate_writes_node(state, scope_node.exit, surrounding_iteration_variables) # Add parent to next frontier next_scopes.add(scope_node.parent) scopes_to_process = next_scopes @@ -1286,9 +1287,8 @@ def _collect_iteration_variables_scope_node(self, surrounding_map_vars: Dict[scope.ScopeTree, Set[str]]) -> Set[str]: map_iteration_variables = surrounding_map_vars[ scope_node] if scope_node in surrounding_map_vars else set() - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - loop_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + loop_iteration_variables = self.ranges_per_state[state].keys() surrounding_iteration_variables = (map_iteration_variables | sdfg_iteration_variables | loop_iteration_variables) @@ -1308,12 +1308,8 @@ def _underapproximate_writes_node(self, :param surrounding_itvars: Iteration variables that surround the map scope """ if isinstance(node, nodes.EntryNode): - internal_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] - external_edges = [ - e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] + internal_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] + external_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] def geticonn(e): return e.src_conn[4:] @@ -1323,12 +1319,8 @@ def geteconn(e): use_dst = False else: - internal_edges = [ - e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] - external_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] + internal_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] + external_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] def geticonn(e): return e.dst_conn[3:] @@ -1339,21 +1331,17 @@ def geteconn(e): use_dst = True for edge in external_edges: - if approximation_dict[edge].is_empty(): + if self.approximation_dict[edge].is_empty(): new_memlet = Memlet() else: internal_edge = next( e for e in internal_edges if geticonn(e) == geteconn(edge)) - aligned_memlet = self._align_memlet( - dfg_state, internal_edge, dst=use_dst) - new_memlet = self._underapproximate_memlets(dfg_state, - aligned_memlet, - node, - True, - connector=geteconn( - edge), + aligned_memlet = self._align_memlet(dfg_state, internal_edge, dst=use_dst) + new_memlet = self._underapproximate_memlets(dfg_state, aligned_memlet, node, True, + connector=geteconn(edge), surrounding_itvars=surrounding_itvars) - approximation_dict[edge] = new_memlet + new_memlet._edge = edge + self.approximation_dict[edge] = new_memlet def _align_memlet(self, state: SDFGState, @@ -1373,16 +1361,16 @@ def _align_memlet(self, is_src = edge.data._is_data_src # Memlet is already aligned if is_src is None or (is_src and not dst) or (not is_src and dst): - res = approximation_dict[edge] + res = self.approximation_dict[edge] return res # Data<->Code memlets always have one data container mpath = state.memlet_path(edge) if not isinstance(mpath[0].src, AccessNode) or not isinstance(mpath[-1].dst, AccessNode): - return approximation_dict[edge] + return self.approximation_dict[edge] # Otherwise, find other data container - result = copy.deepcopy(approximation_dict[edge]) + result = copy.deepcopy(self.approximation_dict[edge]) if dst: node = mpath[-1].dst else: @@ -1390,8 +1378,8 @@ def _align_memlet(self, # Fix memlet fields result.data = node.data - result.subset = approximation_dict[edge].other_subset - result.other_subset = approximation_dict[edge].subset + result.subset = self.approximation_dict[edge].other_subset + result.other_subset = self.approximation_dict[edge].subset result._is_data_src = not is_src return result @@ -1448,9 +1436,9 @@ def _underapproximate_memlets(self, # and union their subsets if union_inner_edges: aggdata = [ - approximation_dict[e] + self.approximation_dict[e] for e in neighboring_edges - if approximation_dict[e].data == memlet.data and approximation_dict[e] != memlet + if self.approximation_dict[e].data == memlet.data and self.approximation_dict[e] != memlet ] else: aggdata = [] @@ -1459,8 +1447,7 @@ def _underapproximate_memlets(self, if arr is None: if memlet.data not in sdfg.arrays: - raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % - memlet.data) + raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % memlet.data) # FIXME: A memlet alone (without an edge) cannot figure out whether it is data<->data or data<->code # so this test cannot be used diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 494f9c39ae..0da8a96165 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -22,6 +22,7 @@ class Modifies(Flag): Symbols = auto() #: Symbols were modified States = auto() #: The number of SDFG states and their connectivity (not their contents) were modified InterstateEdges = auto() #: Contents (conditions/assignments) or existence of inter-state edges were modified + CFG = States | InterstateEdges #: A CFG (any level) was modified (connectivity or number of control flow blocks, but not their contents) AccessNodes = auto() #: Access nodes' existence or properties were modified Scopes = auto() #: Scopes (e.g., Map, Consume, Pipeline) or associated properties were created/removed/modified Tasklets = auto() #: Tasklets were created/removed or their contents were modified @@ -29,7 +30,7 @@ class Modifies(Flag): Memlets = auto() #: Memlets' existence, contents, or properties were modified Nodes = AccessNodes | Scopes | Tasklets | NestedSDFGs #: Modification of any dataflow node (contained in an SDFG state) was made Edges = InterstateEdges | Memlets #: Any edge (memlet or inter-state) was modified - Everything = Descriptors | Symbols | States | InterstateEdges | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) + Everything = Descriptors | Symbols | CFG | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) @properties.make_properties diff --git a/dace/transformation/passes/analysis/__init__.py b/dace/transformation/passes/analysis/__init__.py new file mode 100644 index 0000000000..5bc1f6e3f3 --- /dev/null +++ b/dace/transformation/passes/analysis/__init__.py @@ -0,0 +1 @@ +from .analysis import * diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis/analysis.py similarity index 83% rename from dace/transformation/passes/analysis.py rename to dace/transformation/passes/analysis/analysis.py index c8bb0b7a9c..b230425d00 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -1,7 +1,8 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict -from dace.transformation import pass_pipeline as ppl +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd @@ -16,6 +17,7 @@ @properties.make_properties +@transformation.experimental_cfg_block_compatible class StateReachability(ppl.Pass): """ Evaluates state reachability (which other states can be executed after each state). @@ -28,25 +30,84 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply - return modified & ppl.Modifies.States - - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: + return modified & ppl.Modifies.CFG + + def depends_on(self) -> Set[ppl.Pass | ppl.Pass]: + return {ControlFlowBlockReachability} + + def _region_closure(self, region: ControlFlowRegion, + block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: + closure: Set[SDFGState] = set() + if isinstance(region, LoopRegion): + # Any point inside the loop may reach any other point inside the loop again. + # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. + closure.update(region.all_states()) + + # Add all states that this region can reach in its parent graph to the closure. + for reached_block in block_reach[region.parent_graph.cfg_id][region]: + if isinstance(reached_block, ControlFlowRegion): + closure.update(reached_block.all_states()) + elif isinstance(reached_block, SDFGState): + closure.add(reached_block) + + # Walk up the parent tree. + pivot = region.parent_graph + while pivot and not isinstance(pivot, SDFG): + closure.update(self._region_closure(pivot, block_reach)) + pivot = pivot.parent_graph + return closure + + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. """ + # Ensure control flow block reachability is run if not run within a pipeline. + if not ControlFlowBlockReachability.__name__ in pipeline_res: + cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) + else: + cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Set[SDFGState]] = {} + result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) + for state in sdfg.states(): + for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: + if isinstance(reached, ControlFlowRegion): + result[state].update(reached.all_states()) + elif isinstance(reached, SDFGState): + result[state].add(reached) + if state.parent_graph is not sdfg: + result[state].update(self._region_closure(state.parent_graph, cf_block_reach_dict)) + reachable[sdfg.cfg_id] = result + return reachable - # In networkx this is currently implemented naively for directed graphs. - # The implementation below is faster - # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for n, v in reachable_nodes(sdfg.nx): - result[n] = set(v) +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowBlockReachability(ppl.Pass): + """ + Evaluates control flow block reachability (which control flow block can be executed after each control flow block) + """ - reachable[sdfg.cfg_id] = result + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: + """ + :return: For each control flow region, a dictionary mapping each control flow block to its other reachable + control flow blocks in the same region. + """ + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict(lambda: defaultdict(set)) + for cfg in top_sdfg.all_control_flow_regions(recursive=True): + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + for n, v in reachable_nodes(cfg.nx): + reachable[cfg.cfg_id][n] = set(v) return reachable @@ -99,6 +160,7 @@ def reachable_nodes(G): @properties.make_properties +@transformation.experimental_cfg_block_compatible class SymbolAccessSets(ppl.Pass): """ Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). @@ -116,25 +178,27 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each state and interstate edge to a tuple of its (read, written) symbols. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): - readset = state.free_symbols - # No symbols may be written to inside states. - result[state] = (readset, set()) - for oedge in sdfg.out_edges(state): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[sdfg.cfg_id] = result + for cfg in sdfg.all_control_flow_regions(): + adesc = set(sdfg.arrays.keys()) + result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} + for block in cfg.nodes(): + if isinstance(block, SDFGState): + # No symbols may be written to inside states. + result[block] = (block.free_symbols, set()) + for oedge in cfg.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + top_result[cfg.cfg_id] = result return top_result @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessSets(ppl.Pass): """ Evaluates memory access sets (which arrays/data descriptors are read/written in each state). @@ -179,6 +243,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessStates(ppl.Pass): """ For each data descriptor, creates a set of states in which access nodes of that data are used. @@ -201,13 +266,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Set[SDFGState]] = defaultdict(set) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): result[anode.data].add(state) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): fsyms = e.data.free_symbols & anames for access in fsyms: result[access].update({e.src, e.dst}) @@ -217,6 +282,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessNodes(ppl.Pass): """ For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given @@ -242,7 +308,7 @@ def apply_pass(self, top_sdfg: SDFG, for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict( lambda: defaultdict(lambda: [set(), set()])) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): if state.in_degree(anode) > 0: result[anode.data][state][1].add(anode) @@ -508,6 +574,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessRanges(ppl.Pass): """ For each data descriptor, finds all memlets used to access it (read/write ranges). @@ -544,6 +611,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindReferenceSources(ppl.Pass): """ For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used @@ -586,6 +654,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, @properties.make_properties +@transformation.experimental_cfg_block_compatible class DeriveSDFGConstraints(ppl.Pass): CATEGORY: str = 'Analysis' diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py new file mode 100644 index 0000000000..e11aa945a8 --- /dev/null +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -0,0 +1,229 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from collections import defaultdict +from typing import Any, Dict, List, Set, Tuple + +import networkx as nx + +from dace import SDFG, SDFGState +from dace import data as dt +from dace import properties +from dace.memlet import Memlet +from dace.sdfg import nodes, propagation +from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDictT +from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.scope import ScopeTree +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.subsets import Range +from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.passes.analysis import AccessRanges, ControlFlowBlockReachability + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class StateDataDependence(ppl.Pass): + """ + Analyze the input dependencies and the underapproximated outputs of states. + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.Nodes | ppl.Modifies.Memlets) + + def depends_on(self): + return {UnderapproximateWrites, AccessRanges} + + def _gather_reads_scope(self, state: SDFGState, scope: ScopeTree, + writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]], + not_covered_reads: Set[Memlet], scope_ranges: Dict[str, Range]): + scope_nodes = state.scope_children()[scope.entry] + data_nodes_in_scope: Set[nodes.AccessNode] = set([n for n in scope_nodes if isinstance(nodes.AccessNode)]) + if scope.entry is not None: + # propagate + pass + + for anode in data_nodes_in_scope: + for oedge in state.out_edges(anode): + if not oedge.data.is_empty(): + root_edge = state.memlet_tree(oedge).root().edge + read_subset = root_edge.data.src_subset + covered = False + for [write, to] in writes[anode.data]: + if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): + covered = True + break + if not covered: + not_covered_reads.add(root_edge.data) + + def _state_get_deps(self, state: SDFGState, + underapproximated_writes: UnderapproximateWritesDictT) -> Tuple[Set[Memlet], Set[Memlet]]: + # Collect underapproximated write memlets. + writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]] = defaultdict(lambda: []) + for anode in state.data_nodes(): + for iedge in state.in_edges(anode): + if not iedge.data.is_empty(): + root_edge = state.memlet_tree(iedge).root().edge + if root_edge in underapproximated_writes['approximation']: + writes[anode.data].append([underapproximated_writes['approximation'][root_edge], anode]) + else: + writes[anode.data].append([root_edge.data, anode]) + + # Go over (overapproximated) reads and check if they are covered by writes. + not_covered_reads: List[Tuple[MultiConnectorEdge[Memlet], Memlet]] = [] + for anode in state.data_nodes(): + for oedge in state.out_edges(anode): + if not oedge.data.is_empty(): + if oedge.data.data != anode.data: + # Special case for memlets copying data out of the scope, which are by default aligned with the + # outside data container. In this case, the source container must either be a scalar, or the + # read subset is contained in the memlet's `other_subset` property. + # See `dace.sdfg.propagation.align_memlet` for more. + desc = state.sdfg.data(anode.data) + if oedge.data.other_subset is not None: + read_subset = oedge.data.other_subset + elif isinstance(desc, dt.Scalar) or (isinstance(desc, dt.Array) and desc.total_size == 1): + read_subset = Range([(0, 0, 1)] * len(desc.shape)) + else: + raise RuntimeError('Invalid memlet range detected in StateDataDependence analysis') + else: + read_subset = oedge.data.src_subset or oedge.data.subset + covered = False + for [write, to] in writes[anode.data]: + if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): + covered = True + break + if not covered: + #root_edge = state.memlet_tree(oedge).root().edge + #not_covered_reads.append([root_edge, root_edge.data]) + not_covered_reads.append([oedge, oedge.data]) + # Make sure all reads are propagated if they happen inside maps. We do not need to do this for writes, because + # it is already taken care of by the write underapproximation analysis pass. + self._recursive_propagate_reads(state, state.scope_tree()[None], not_covered_reads) + + write_set = set() + for data in writes: + for memlet, _ in writes[data]: + write_set.add(memlet) + + read_set = set() + for reads in not_covered_reads: + read_set.add(reads[1]) + + return read_set, write_set + + def _recursive_propagate_reads(self, state: SDFGState, scope: ScopeTree, + read_edges: Set[Tuple[MultiConnectorEdge[Memlet], Memlet]]): + for child in scope.children: + self._recursive_propagate_reads(state, child, read_edges) + + if scope.entry is not None: + if isinstance(scope.entry, nodes.MapEntry): + for read_tuple in read_edges: + read_edge, read_memlet = read_tuple + for param in scope.entry.map.params: + if param in read_memlet.free_symbols: + aligned_memlet = propagation.align_memlet(state, read_edge, True) + propagated_memlet = propagation.propagate_memlet(state, aligned_memlet, scope.entry, True) + read_tuple[1] = propagated_memlet + + def apply_pass(self, top_sdfg: SDFG, + pipeline_results: Dict[str, Any]) -> Dict[int, Dict[SDFGState, Tuple[Set[Memlet], Set[Memlet]]]]: + """ + :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. + """ + + results = defaultdict(lambda: defaultdict(lambda: [set(), set()])) + + underapprox_writes_dict: Dict[int, Any] = pipeline_results[UnderapproximateWrites.__name__] + for sdfg in top_sdfg.all_sdfgs_recursive(): + uapprox_writes = underapprox_writes_dict[sdfg.cfg_id] + for state in sdfg.states(): + input_dependencies, output_dependencies = self._state_get_deps(state, uapprox_writes) + results[sdfg.cfg_id][state] = [input_dependencies, output_dependencies] + + return results + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class CFGDataDependence(ppl.Pass): + """ + Analyze the input dependencies and the underapproximated outputs of control flow graphs / regions. + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def depends_on(self): + return {StateDataDependence, ControlFlowBlockReachability} + + def _recursive_get_deps_region(self, cfg: ControlFlowRegion, + results: Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]], + state_deps: Dict[int, Dict[SDFGState, Tuple[Set[Memlet], Set[Memlet]]]], + cfg_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] + ) -> Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]: + # Collect all individual reads and writes happening inside the region. + region_reads: Dict[str, List[Tuple[Memlet, ControlFlowBlock]]] = defaultdict(list) + region_writes: Dict[str, List[Tuple[Memlet, ControlFlowBlock]]] = defaultdict(list) + for node in cfg.nodes(): + if isinstance(node, SDFGState): + for read in state_deps[node.sdfg.cfg_id][node][0]: + region_reads[read.data].append([read, node]) + for write in state_deps[node.sdfg.cfg_id][node][1]: + region_writes[write.data].append([write, node]) + elif isinstance(node, ControlFlowRegion): + sub_reads, sub_writes = self._recursive_get_deps_region(node, results, state_deps, cfg_reach) + for data in sub_reads: + for read in sub_reads[data]: + region_reads[data].append([read, node]) + for data in sub_writes: + for write in sub_writes[data]: + region_writes[data].append([write, node]) + + # Through reachability analysis, check which writes cover which reads. + # TODO: make sure this doesn't cover up reads if we have a cycle in the CFG. + not_covered_reads: Dict[str, Set[Memlet]] = defaultdict(set) + for data in region_reads: + for read, read_block in region_reads[data]: + covered = False + for write, write_block in region_writes[data]: + if (write.subset.covers_precise(read.src_subset or read.subset) and + write_block is not read_block and + nx.has_path(cfg.nx, write_block, read_block)): + covered = True + break + if not covered: + not_covered_reads[data].add(read) + + write_set: Dict[str, Set[Memlet]] = defaultdict(set) + for data in region_writes: + for memlet, _ in region_writes[data]: + write_set[data].add(memlet) + + results[cfg.cfg_id] = [not_covered_reads, write_set] + + return not_covered_reads, write_set + + def apply_pass(self, top_sdfg: SDFG, + pipeline_res: Dict[str, Any]) -> Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]]: + """ + :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. + """ + + results = defaultdict(lambda: defaultdict(lambda: [defaultdict(set), defaultdict(set)])) + + state_deps_dict = pipeline_res[StateDataDependence.__name__] + cfb_reachability_dict = pipeline_res[ControlFlowBlockReachability.__name__] + for sdfg in top_sdfg.all_sdfgs_recursive(): + self._recursive_get_deps_region(sdfg, results, state_deps_dict, cfb_reachability_dict) + + return results diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py new file mode 100644 index 0000000000..293021de9c --- /dev/null +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -0,0 +1,213 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +from collections import defaultdict +from typing import Any, Dict, Optional, Set, Tuple + +import sympy + +from dace import SDFG, properties, symbolic, transformation +from dace.memlet import Memlet +from dace.sdfg.state import LoopRegion +from dace.subsets import Range, SubsetUnion +from dace.transformation import pass_pipeline as ppl +from dace.transformation.pass_pipeline import Pass +from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class LoopCarryDependencyAnalysis(ppl.Pass): + """ + Analyze the data dependencies between loop iterations for loop regions. + """ + + CATEGORY: str = 'Analysis' + + _non_analyzable_loops: Set[LoopRegion] + + def __init__(self): + self._non_analyzable_loops = set() + super().__init__() + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def depends_on(self) -> Set[type[Pass] | Pass]: + return {CFGDataDependence} + + def _intersects(self, loop: LoopRegion, write_subset: Range, read_subset: Range, update: sympy.Basic) -> bool: + """ + Check if a write subset intersects a read subset after being offset by the loop stride. The offset is performed + based on the symbolic loop update assignment expression. + """ + offset = update - symbolic.symbol(loop.loop_variable) + offset_list = [] + for i in range(write_subset.dims()): + if loop.loop_variable in write_subset.get_free_symbols_by_indices([i]): + offset_list.append(offset) + else: + offset_list.append(0) + offset_write = write_subset.offset_new(offset_list, True) + return offset_write.intersects(read_subset) + + def apply_pass(self, top_sdfg: SDFG, + pipeline_results: Dict[str, Any]) -> Dict[int, Dict[LoopRegion, Dict[Memlet, Set[Memlet]]]]: + """ + :return: For each SDFG, a dictionary mapping loop regions to a dictionary that resolves reads to writes in the + same loop, from which they may carry a RAW dependency. + """ + results = defaultdict(lambda: defaultdict(dict)) + + cfg_dependency_dict: Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]] = pipeline_results[ + CFGDataDependence.__name__ + ] + for cfg in top_sdfg.all_control_flow_regions(recursive=True): + if isinstance(cfg, LoopRegion): + loop_inputs, loop_outputs = cfg_dependency_dict[cfg.cfg_id] + update_assignment = None + loop_dependencies: Dict[Memlet, Set[Memlet]] = dict() + + for data in loop_inputs: + if not data in loop_outputs: + continue + + for input in loop_inputs[data]: + read_subset = input.src_subset or input.subset + dep_candidates: Set[Memlet] = set() + if cfg.loop_variable and cfg.loop_variable in input.free_symbols: + # If the iteration variable is involved in an access, we need to first offset it by the loop + # stride and then check for an overlap/intersection. If one is found after offsetting, there + # is a RAW loop carry dependency. + for output in loop_outputs[data]: + # Get and cache the update assignment for the loop. + if update_assignment is None and not cfg in self._non_analyzable_loops: + update_assignment = get_update_assignment(cfg) + if update_assignment is None: + self._non_analyzable_loops(cfg) + + if isinstance(output.subset, SubsetUnion): + if any([self._intersects(cfg, s, read_subset, update_assignment) + for s in output.subset.subset_list]): + dep_candidates.add(output) + elif self._intersects(cfg, output.subset, read_subset, update_assignment): + dep_candidates.add(output) + else: + # Check for basic overlaps/intersections in RAW loop carry dependencies, when there is no + # iteration variable involved. + for output in loop_outputs[data]: + if isinstance(output.subset, SubsetUnion): + if any([s.intersects(read_subset) for s in output.subset.subset_list]): + dep_candidates.add(output) + elif output.subset.intersects(read_subset): + dep_candidates.add(output) + loop_dependencies[input] = dep_candidates + results[cfg.sdfg.cfg_id][cfg] = loop_dependencies + + return results + + +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = ast.unparse(node.value) + return self.generic_visit(node) + + +def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). + """ + end: Optional[symbolic.SymbolicType] = None + a = sympy.Wild('a') + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + itersym = symbolic.pystr_to_symbolic(loop.loop_variable) + match = condition.match(itersym < a) + if match: + end = match[a] - 1 + if end is None: + match = condition.match(itersym <= a) + if match: + end = match[a] + if end is None: + match = condition.match(itersym > a) + if match: + end = match[a] + 1 + if end is None: + match = condition.match(itersym >= a) + if match: + end = match[a] + return end + + +def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's init statement to identify the exact init assignment expression. + """ + init_stmt = loop.init_statement + if init_stmt is None: + return None + + init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] + assignments: Dict[str, str] = {} + for code in init_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's update statement to identify the exact update assignment expression. + """ + update_stmt = loop.update_statement + if update_stmt is None: + return None + + update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + assignments: Dict[str, str] = {} + for code in update_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + update_assignment = get_update_assignment(loop) + if update_assignment: + return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) + return None From 4aa13eda11d990185cc73bf7595f847f713e0d4b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 13:53:02 +0200 Subject: [PATCH 002/108] Fix type --- dace/transformation/passes/analysis/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index b230425d00..d0fc8decdc 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -32,7 +32,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.CFG - def depends_on(self) -> Set[ppl.Pass | ppl.Pass]: + def depends_on(self): return {ControlFlowBlockReachability} def _region_closure(self, region: ControlFlowRegion, From a228f34a8d1e448c31ae97bfa15f6c7de3a5a535 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 14:28:14 +0200 Subject: [PATCH 003/108] Fix types --- .../analysis/writeset_underapproximation.py | 30 ++++++++++++------- .../passes/analysis/analysis.py | 2 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index afc34add30..3dcdbf3473 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -4,25 +4,33 @@ an SDFG. """ -from collections import defaultdict import copy import itertools +import sys import warnings -from typing import Any, Dict, List, Set, Tuple, Type, TypedDict, Union +from collections import defaultdict +from typing import Dict, List, Set, Tuple, Union + +if sys.version >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + import sympy import dace +from dace import SDFG, Memlet, data, dtypes, registry, subsets, symbolic +from dace.sdfg import SDFGState +from dace.sdfg import graph +from dace.sdfg import graph as gr +from dace.sdfg import nodes, scope +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.nodes import AccessNode, NestedSDFG from dace.sdfg.state import LoopRegion -from dace.transformation import transformation from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -from dace.transformation.pass_pipeline import Modifies, Pass -from dace import registry, subsets, symbolic, dtypes, data, SDFG, Memlet -from dace.sdfg.nodes import NestedSDFG, AccessNode -from dace.sdfg import nodes, SDFGState, graph as gr -from dace.sdfg.analysis import cfg as cfg_analysis from dace.transformation import pass_pipeline as ppl -from dace.sdfg import graph -from dace.sdfg import scope +from dace.transformation import transformation +from dace.transformation.pass_pipeline import Modifies from dace.transformation.passes.analysis import loop_analysis @@ -807,8 +815,8 @@ def _find_for_loops(self, """ # We import here to avoid cyclic imports. - from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop # dictionary mapping loop headers to beginstate, loopstates, looprange identified_loops = {} diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index d0fc8decdc..1a4ab01b88 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -62,7 +62,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS :return: A dictionary mapping each state to its other reachable states. """ # Ensure control flow block reachability is run if not run within a pipeline. - if not ControlFlowBlockReachability.__name__ in pipeline_res: + if pipeline_res is None or not ControlFlowBlockReachability.__name__ in pipeline_res: cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) else: cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] From 10c3b6c74ec8074cd30a74aace2614d197b21e77 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 16:56:30 +0200 Subject: [PATCH 004/108] Update tests --- .../analysis/writeset_underapproximation.py | 4 +- .../analysis/control_flow_region_analysis.py | 4 +- .../analysis/control_flow_region_analysis.py | 80 +++++++++++++++++++ 3 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 tests/passes/analysis/control_flow_region_analysis.py diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index 3dcdbf3473..c4a685e62a 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -11,7 +11,7 @@ from collections import defaultdict from typing import Dict, List, Set, Tuple, Union -if sys.version >= (3, 8): +if sys.version_info >= (3, 8): from typing import TypedDict else: from typing_extensions import TypedDict @@ -31,7 +31,6 @@ from dace.transformation import pass_pipeline as ppl from dace.transformation import transformation from dace.transformation.pass_pipeline import Modifies -from dace.transformation.passes.analysis import loop_analysis @registry.make_registry @@ -782,6 +781,7 @@ def _underapproximate_writes_sdfg(self, sdfg: SDFG): Underapproximates write-sets of loops, maps and nested SDFGs in the given SDFG. """ from dace.transformation.helpers import split_interstate_edges + from dace.transformation.passes.analysis import loop_analysis split_interstate_edges(sdfg) loops = self._find_for_loops(sdfg) diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py index e11aa945a8..92b2badecf 100644 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -16,7 +16,7 @@ from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.subsets import Range from dace.transformation import pass_pipeline as ppl, transformation -from dace.transformation.passes.analysis import AccessRanges, ControlFlowBlockReachability +from dace.transformation.passes.analysis.analysis import AccessRanges, ControlFlowBlockReachability @properties.make_properties @@ -97,8 +97,6 @@ def _state_get_deps(self, state: SDFGState, covered = True break if not covered: - #root_edge = state.memlet_tree(oedge).root().edge - #not_covered_reads.append([root_edge, root_edge.data]) not_covered_reads.append([oedge, oedge.data]) # Make sure all reads are propagated if they happen inside maps. We do not need to do this for writes, because # it is already taken care of by the write underapproximation analysis pass. diff --git a/tests/passes/analysis/control_flow_region_analysis.py b/tests/passes/analysis/control_flow_region_analysis.py new file mode 100644 index 0000000000..64461edd85 --- /dev/null +++ b/tests/passes/analysis/control_flow_region_analysis.py @@ -0,0 +1,80 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests analysis passes related to control flow regions (control_flow_region_analysis.py). """ + + +import dace +from dace.memlet import Memlet +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import LoopRegion, SDFGState +from dace.transformation.pass_pipeline import Pipeline +from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence, StateDataDependence + + +def test_simple_state_data_dependence_with_self_contained_read(): + N = dace.symbol('N') + + @dace.program + def myprog(A: dace.float64[N], B: dace.float64): + for i in dace.map[0:N/2]: + with dace.tasklet: + in1 << B[i] + out1 >> A[i] + out1 = in1 + 1 + with dace.tasklet: + in1 << B[i] + out1 >> B[N - (i + 1)] + out1 = in1 - 1 + for i in dace.map[0:N/2]: + with dace.tasklet: + in1 << A[i] + out1 >> B[i] + out1 = in1 * 2 + + sdfg = myprog.to_sdfg() + + res = {} + Pipeline([StateDataDependence()]).apply_pass(sdfg, res) + state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] + + assert len(state_data_deps[0]) == 1 + read_memlet: Memlet = list(state_data_deps[0])[0] + assert read_memlet.data == 'B' + assert read_memlet.subset[0][0] == 0 + assert read_memlet.subset[0][1] == 0.5 * N - 1 or read_memlet.subset[0][1] == N / 2 - 1 + + assert len(state_data_deps[1]) == 3 + + +''' +def test_nested_cf_region_data_dependence(): + N = dace.symbol('N') + + @dace.program + def myprog(A: dace.float64[N], B: dace.float64): + for i in range(N): + with dace.tasklet: + in1 << B[i] + out1 >> A[i] + out1 = in1 + 1 + for i in range(N): + with dace.tasklet: + in1 << A[i] + out1 >> B[i] + out1 = in1 * 2 + + myprog.use_experimental_cfg_blocks = True + + sdfg = myprog.to_sdfg() + + res = {} + pipeline = Pipeline([CFGDataDependence()]) + pipeline.__experimental_cfg_block_compatible__ = True + pipeline.apply_pass(sdfg, res) + + print(sdfg) + ''' + + +if __name__ == '__main__': + test_simple_state_data_dependence_with_self_contained_read() + #test_nested_cf_region_data_dependence() From 77ca17f4db1984a2adc7da7440ec8cc340c16543 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Sep 2024 17:51:50 +0200 Subject: [PATCH 005/108] Fixes --- dace/frontend/python/parser.py | 2 + dace/transformation/helpers.py | 4 +- .../passes/analysis/analysis.py | 80 ++++++++++++------- .../passes/analysis/loop_analysis.py | 2 +- ...y => control_flow_region_analysis_test.py} | 0 5 files changed, 56 insertions(+), 32 deletions(-) rename tests/passes/analysis/{control_flow_region_analysis.py => control_flow_region_analysis_test.py} (100%) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index e55829933c..e0900c749b 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -498,6 +498,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdutils.inline_control_flow_regions(sdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + sdfg.reset_cfg_list() + # Apply simplification pass automatically if not cached and (simplify == True or (simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))): diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 0d583236cb..6c17538a37 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -379,7 +379,7 @@ def nest_state_subgraph(sdfg: SDFG, SDFG. :raise ValueError: The subgraph is contained in more than one scope. """ - if state.parent != sdfg: + if state.sdfg != sdfg: raise KeyError('State does not belong to given SDFG') if subgraph is not state and subgraph.graph is not state: raise KeyError('Subgraph does not belong to given state') @@ -433,7 +433,7 @@ def nest_state_subgraph(sdfg: SDFG, # top-level graph) data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode)) # Find other occurrences in SDFG - other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes() + other_nodes = set(n.data for s in sdfg.states() for n in s.nodes() if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes()) subgraph_transients = set() for data in data_in_subgraph: diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 1a4ab01b88..095319f807 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -35,28 +35,6 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {ControlFlowBlockReachability} - def _region_closure(self, region: ControlFlowRegion, - block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: - closure: Set[SDFGState] = set() - if isinstance(region, LoopRegion): - # Any point inside the loop may reach any other point inside the loop again. - # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. - closure.update(region.all_states()) - - # Add all states that this region can reach in its parent graph to the closure. - for reached_block in block_reach[region.parent_graph.cfg_id][region]: - if isinstance(reached_block, ControlFlowRegion): - closure.update(reached_block.all_states()) - elif isinstance(reached_block, SDFGState): - closure.add(reached_block) - - # Walk up the parent tree. - pivot = region.parent_graph - while pivot and not isinstance(pivot, SDFG): - closure.update(self._region_closure(pivot, block_reach)) - pivot = pivot.parent_graph - return closure - def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. @@ -71,12 +49,8 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) for state in sdfg.states(): for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: - if isinstance(reached, ControlFlowRegion): - result[state].update(reached.all_states()) - elif isinstance(reached, SDFGState): + if isinstance(reached, SDFGState): result[state].add(reached) - if state.parent_graph is not sdfg: - result[state].update(self._region_closure(state.parent_graph, cf_block_reach_dict)) reachable[sdfg.cfg_id] = result return reachable @@ -90,24 +64,72 @@ class ControlFlowBlockReachability(ppl.Pass): CATEGORY: str = 'Analysis' + contain_to_single_level = properties.Property(dtype=bool, default=False) + + def __init__(self, contain_to_single_level=False) -> None: + super().__init__() + + self.contain_to_single_level = contain_to_single_level + def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG + def _region_closure(self, region: ControlFlowRegion, + block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: + closure: Set[SDFGState] = set() + if isinstance(region, LoopRegion): + # Any point inside the loop may reach any other point inside the loop again. + # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. + closure.update(region.all_control_flow_blocks()) + + # Add all states that this region can reach in its parent graph to the closure. + for reached_block in block_reach[region.parent_graph.cfg_id][region]: + if isinstance(reached_block, ControlFlowRegion): + closure.update(reached_block.all_control_flow_blocks()) + closure.add(reached_block) + + # Walk up the parent tree. + pivot = region.parent_graph + while pivot and not isinstance(pivot, SDFG): + closure.update(self._region_closure(pivot, block_reach)) + pivot = pivot.parent_graph + return closure + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: """ :return: For each control flow region, a dictionary mapping each control flow block to its other reachable control flow blocks in the same region. """ - reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict(lambda: defaultdict(set)) + single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( + lambda: defaultdict(set) + ) for cfg in top_sdfg.all_control_flow_regions(recursive=True): # In networkx this is currently implemented naively for directed graphs. # The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) for n, v in reachable_nodes(cfg.nx): - reachable[cfg.cfg_id][n] = set(v) + single_level_reachable[cfg.cfg_id][n] = set(v) + if isinstance(cfg, LoopRegion): + single_level_reachable[cfg.cfg_id][n].update(cfg.nodes()) + + if self.contain_to_single_level: + return single_level_reachable + + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = {} + for sdfg in top_sdfg.all_sdfgs_recursive(): + for cfg in sdfg.all_control_flow_regions(): + result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) + for block in cfg.nodes(): + for reached in single_level_reachable[block.parent_graph.cfg_id][block]: + if isinstance(reached, ControlFlowRegion): + result[block].update(reached.all_control_flow_blocks()) + result[block].add(reached) + if block.parent_graph is not sdfg: + result[block].update(self._region_closure(block.parent_graph, single_level_reachable)) + reachable[cfg.cfg_id] = result return reachable diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 293021de9c..dd8c5f7446 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -36,7 +36,7 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG - def depends_on(self) -> Set[type[Pass] | Pass]: + def depends_on(self): return {CFGDataDependence} def _intersects(self, loop: LoopRegion, write_subset: Range, read_subset: Range, update: sympy.Basic) -> bool: diff --git a/tests/passes/analysis/control_flow_region_analysis.py b/tests/passes/analysis/control_flow_region_analysis_test.py similarity index 100% rename from tests/passes/analysis/control_flow_region_analysis.py rename to tests/passes/analysis/control_flow_region_analysis_test.py From 4e6035d68e41ef33a4e06253bb472af662fe2e2f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 17 Sep 2024 14:03:19 +0200 Subject: [PATCH 006/108] Fix tests --- .../analysis/writeset_underapproximation.py | 13 ++- dace/sdfg/propagation.py | 23 ++--- .../analysis/control_flow_region_analysis.py | 2 +- .../control_flow_region_analysis_test.py | 45 ++++----- .../writeset_underapproximation_test.py | 94 ++++++++++++------- 5 files changed, 103 insertions(+), 74 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index c4a685e62a..e1b88f9401 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -1,7 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ -Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in -an SDFG. +Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in an SDFG. """ import copy @@ -736,11 +735,11 @@ def apply_pass(self, top_sdfg: dace.SDFG, _) -> Dict[int, UnderapproximateWrites for sdfg in top_sdfg.all_sdfgs_recursive(): # Clear the global dictionaries. - self.approximation_dict.clear() - self.loop_write_dict.clear() - self.loop_dict.clear() - self.iteration_variables.clear() - self.ranges_per_state.clear() + self.approximation_dict = {} + self.loop_write_dict = {} + self.loop_dict = {} + self.iteration_variables = {} + self.ranges_per_state = defaultdict(lambda: {}) # fill the approximation dictionary with the original edges as keys and the edges with the # approximated memlets as values diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1c038dd2e4..6447d8f89b 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -4,21 +4,22 @@ from internal memory accesses and scope ranges). """ -from collections import deque import copy -from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -import itertools import functools +import itertools +import warnings +from collections import deque +from typing import List, Set + import sympy -from sympy import ceiling, Symbol +from sympy import Symbol, ceiling from sympy.concrete.summations import Sum -import warnings -import networkx as nx -from dace import registry, subsets, symbolic, dtypes, data +from dace import data, dtypes, registry, subsets, symbolic from dace.memlet import Memlet -from dace.sdfg import nodes, graph as gr -from typing import List, Set +from dace.sdfg import graph as gr +from dace.sdfg import nodes +from dace.symbolic import issymbolic, pystr_to_symbolic, simplify @registry.make_registry @@ -569,8 +570,8 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): """ # We import here to avoid cyclic imports. - from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop condition_edges = {} @@ -739,8 +740,8 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: # We import here to avoid cyclic imports. from dace.sdfg import InterstateEdge - from dace.transformation.helpers import split_interstate_edges from dace.sdfg.analysis import cfg + from dace.transformation.helpers import split_interstate_edges # Reset the state edge annotations (which may have changed due to transformations) reset_state_annotations(sdfg) diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py index 92b2badecf..265c6465ba 100644 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -214,7 +214,7 @@ def _recursive_get_deps_region(self, cfg: ControlFlowRegion, def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict[str, Any]) -> Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]]: """ - :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. + :return: For each CFG, a dictionary mapping control flow regions to sets of their input and output memlets. """ results = defaultdict(lambda: defaultdict(lambda: [defaultdict(set), defaultdict(set)])) diff --git a/tests/passes/analysis/control_flow_region_analysis_test.py b/tests/passes/analysis/control_flow_region_analysis_test.py index 64461edd85..d1ea5161bf 100644 --- a/tests/passes/analysis/control_flow_region_analysis_test.py +++ b/tests/passes/analysis/control_flow_region_analysis_test.py @@ -4,33 +4,36 @@ import dace from dace.memlet import Memlet -from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.propagation import propagate_memlets_sdfg +from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.state import LoopRegion, SDFGState from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence, StateDataDependence def test_simple_state_data_dependence_with_self_contained_read(): + sdfg = SDFG('myprog') N = dace.symbol('N') - - @dace.program - def myprog(A: dace.float64[N], B: dace.float64): - for i in dace.map[0:N/2]: - with dace.tasklet: - in1 << B[i] - out1 >> A[i] - out1 = in1 + 1 - with dace.tasklet: - in1 << B[i] - out1 >> B[N - (i + 1)] - out1 = in1 - 1 - for i in dace.map[0:N/2]: - with dace.tasklet: - in1 << A[i] - out1 >> B[i] - out1 = in1 * 2 - - sdfg = myprog.to_sdfg() + sdfg.add_array('A', (N,), dace.float32) + sdfg.add_array('B', (N,), dace.float32) + mystate = sdfg.add_state('mystate', is_start_block=True) + b_read = mystate.add_access('B') + b_write_second_half = mystate.add_access('B') + b_write_first_half = mystate.add_access('B') + a_read_write = mystate.add_access('A') + first_entry, first_exit = mystate.add_map('map_one', {'i': '0:0.5*N'}) + second_entry, second_exit = mystate.add_map('map_two', {'i': '0:0.5*N'}) + t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 + 1.0') + t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') + t3 = mystate.add_tasklet('t3', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') + mystate.add_memlet_path(b_read, first_entry, t1, memlet=Memlet('B[i]'), dst_conn='i1') + mystate.add_memlet_path(b_read, first_entry, t2, memlet=Memlet('B[i]'), dst_conn='i1') + mystate.add_memlet_path(t1, first_exit, a_read_write, memlet=Memlet('A[i]'), src_conn='o1') + mystate.add_memlet_path(t2, first_exit, b_write_second_half, memlet=Memlet('B[N - (i + 1)]'), src_conn='o1') + mystate.add_memlet_path(a_read_write, second_entry, t3, memlet=Memlet('A[i]'), dst_conn='i1') + mystate.add_memlet_path(t3, second_exit, b_write_first_half, memlet=Memlet('B[i]'), src_conn='o1') + + propagate_memlets_sdfg(sdfg) res = {} Pipeline([StateDataDependence()]).apply_pass(sdfg, res) @@ -40,7 +43,7 @@ def myprog(A: dace.float64[N], B: dace.float64): read_memlet: Memlet = list(state_data_deps[0])[0] assert read_memlet.data == 'B' assert read_memlet.subset[0][0] == 0 - assert read_memlet.subset[0][1] == 0.5 * N - 1 or read_memlet.subset[0][1] == N / 2 - 1 + assert read_memlet.subset[0][1] == 0.5 * N - 1 assert len(state_data_deps[1]) == 3 diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 7d5272d80a..d0c0e03209 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -9,8 +9,6 @@ M = dace.symbol("M") K = dace.symbol("K") -pipeline = Pipeline([UnderapproximateWrites()]) - def test_2D_map_overwrites_2D_array(): """ @@ -33,9 +31,10 @@ def test_2D_map_overwrites_2D_array(): output_nodes={'B': a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results['approximation'] + result = results[sdfg.cfg_id]['approximation'] edge = map_state.in_edges(a1)[0] result_subset_list = result[edge].subset.subset_list result_subset = result_subset_list[0] @@ -65,9 +64,10 @@ def test_2D_map_added_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -94,9 +94,10 @@ def test_2D_map_multiplied_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -121,9 +122,10 @@ def test_1D_map_one_index_multiple_dims(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -146,9 +148,10 @@ def test_1D_map_one_index_squared(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -185,9 +188,10 @@ def test_map_tree_full_write(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge = Range.from_string("0:M, _i") result_inner_edge_0 = result[inner_edge_0].subset.subset_list[0] @@ -230,9 +234,10 @@ def test_map_tree_no_write_multiple_indices(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] result_inner_edge_0 = result[inner_edge_0].subset.subset_list result_inner_edge_1 = result[inner_edge_1].subset.subset_list result_outer_edge = result[outer_edge].subset.subset_list @@ -273,9 +278,10 @@ def test_map_tree_multiple_indices_per_dimension(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge_1 = Range.from_string("0:M, _i") result_inner_edge_1 = result[inner_edge_1].subset.subset_list[0] @@ -300,11 +306,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] nsdfg = sdfg.cfg_list[1].parent_nsdfg_node map_state = sdfg.states()[0] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] edge = map_state.out_edges(nsdfg)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -323,11 +330,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] map_state = sdfg.states()[0] edge = map_state.in_edges(map_state.data_nodes()[0])[0] - result = results["approximation"] + result = results[sdfg.cfg_id]["approximation"] expected_subset = Range.from_string("0:N, 0:M") assert (str(result[edge].subset.subset_list[0]) == str(expected_subset)) @@ -357,9 +365,10 @@ def test_map_in_loop(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] expected_subset = Range.from_string("0:N, 0:M") assert (str(result[guard]["B"].subset.subset_list[0]) == str(expected_subset)) @@ -390,9 +399,10 @@ def test_map_in_loop_multiplied_indices_first_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] assert (guard not in result.keys() or len(result[guard]) == 0) @@ -421,9 +431,10 @@ def test_map_in_loop_multiplied_indices_second_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] assert (guard not in result.keys() or len(result[guard]) == 0) @@ -444,8 +455,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None @@ -478,9 +490,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None @@ -510,15 +523,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -542,15 +556,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -574,7 +589,8 @@ def test_simple_loop_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) @@ -598,7 +614,8 @@ def test_loop_2D_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -629,7 +646,8 @@ def test_loop_2D_propagation_gap_symbolic(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert ("A" not in result[guard1].keys()) assert ("A" not in result[guard2].keys()) @@ -657,7 +675,8 @@ def test_2_loops_overwrite(): loop_tasklet_2 = loop_body_2.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_2.add_edge(loop_tasklet_2, "a", a1, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard_1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard_2]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) @@ -687,7 +706,8 @@ def test_loop_2D_overwrite_propagation_gap_non_empty(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -717,7 +737,8 @@ def test_loop_nest_multiplied_indices(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i,i*j]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -748,7 +769,8 @@ def test_loop_nest_empty_nested_loop(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -779,7 +801,8 @@ def test_loop_nest_inner_loop_conditional(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[k]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id]["loop_approximation"] assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 in result.keys() and "A" in result[guard2].keys() and str(result[guard2]['A'].subset) == "0:N") @@ -799,9 +822,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] write_set = None accessnode = None for node, _ in sdfg.all_nodes_recursive(): @@ -828,10 +852,11 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] # find write set - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id]["approximation"] accessnode = None write_set = None for node, _ in sdfg.all_nodes_recursive(): @@ -864,9 +889,10 @@ def test_loop_break(): loop_tasklet = loop_body_1.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_1.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id]["loop_approximation"] assert (guard3 not in result.keys() or "A" not in result[guard3].keys()) From b61a283e1c624da2c9ef4b8634e2081eb5b15159 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 17 Sep 2024 16:09:27 +0200 Subject: [PATCH 007/108] Add tests --- .../analysis/control_flow_region_analysis.py | 2 +- .../control_flow_region_analysis_test.py | 100 ++++++++++++------ 2 files changed, 68 insertions(+), 34 deletions(-) diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py index 265c6465ba..377765c31b 100644 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ b/dace/transformation/passes/analysis/control_flow_region_analysis.py @@ -35,7 +35,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.Nodes | ppl.Modifies.Memlets) def depends_on(self): - return {UnderapproximateWrites, AccessRanges} + return {UnderapproximateWrites} def _gather_reads_scope(self, state: SDFGState, scope: ScopeTree, writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]], diff --git a/tests/passes/analysis/control_flow_region_analysis_test.py b/tests/passes/analysis/control_flow_region_analysis_test.py index d1ea5161bf..bf0742f3f1 100644 --- a/tests/passes/analysis/control_flow_region_analysis_test.py +++ b/tests/passes/analysis/control_flow_region_analysis_test.py @@ -1,21 +1,18 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests analysis passes related to control flow regions (control_flow_region_analysis.py). """ - import dace from dace.memlet import Memlet -from dace.sdfg.propagation import propagate_memlets_sdfg -from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import LoopRegion, SDFGState +from dace.sdfg.sdfg import SDFG from dace.transformation.pass_pipeline import Pipeline -from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence, StateDataDependence +from dace.transformation.passes.analysis.control_flow_region_analysis import StateDataDependence -def test_simple_state_data_dependence_with_self_contained_read(): +def test_state_data_dependence_with_contained_read(): sdfg = SDFG('myprog') N = dace.symbol('N') - sdfg.add_array('A', (N,), dace.float32) - sdfg.add_array('B', (N,), dace.float32) + sdfg.add_array('A', (N, ), dace.float32) + sdfg.add_array('B', (N, ), dace.float32) mystate = sdfg.add_state('mystate', is_start_block=True) b_read = mystate.add_access('B') b_write_second_half = mystate.add_access('B') @@ -33,8 +30,6 @@ def test_simple_state_data_dependence_with_self_contained_read(): mystate.add_memlet_path(a_read_write, second_entry, t3, memlet=Memlet('A[i]'), dst_conn='i1') mystate.add_memlet_path(t3, second_exit, b_write_first_half, memlet=Memlet('B[i]'), src_conn='o1') - propagate_memlets_sdfg(sdfg) - res = {} Pipeline([StateDataDependence()]).apply_pass(sdfg, res) state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] @@ -48,36 +43,75 @@ def test_simple_state_data_dependence_with_self_contained_read(): assert len(state_data_deps[1]) == 3 -''' -def test_nested_cf_region_data_dependence(): +def test_state_data_dependence_with_contained_read_in_map(): + sdfg = SDFG('myprog') N = dace.symbol('N') + sdfg.add_array('A', (N, ), dace.float32) + sdfg.add_transient('tmp', (N, ), dace.float32) + sdfg.add_array('B', (N, ), dace.float32) + mystate = sdfg.add_state('mystate', is_start_block=True) + a_read = mystate.add_access('A') + tmp = mystate.add_access('tmp') + b_write = mystate.add_access('B') + m_entry, m_exit = mystate.add_map('my_map', {'i': 'N'}) + t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') + t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') + mystate.add_memlet_path(a_read, m_entry, t1, memlet=Memlet('A[i]'), dst_conn='i1') + mystate.add_memlet_path(t1, tmp, memlet=Memlet('tmp[i]'), src_conn='o1') + mystate.add_memlet_path(tmp, t2, memlet=Memlet('tmp[i]'), dst_conn='i1') + mystate.add_memlet_path(t2, m_exit, b_write, memlet=Memlet('B[i]'), src_conn='o1') - @dace.program - def myprog(A: dace.float64[N], B: dace.float64): - for i in range(N): - with dace.tasklet: - in1 << B[i] - out1 >> A[i] - out1 = in1 + 1 - for i in range(N): - with dace.tasklet: - in1 << A[i] - out1 >> B[i] - out1 = in1 * 2 + res = {} + Pipeline([StateDataDependence()]).apply_pass(sdfg, res) + state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] - myprog.use_experimental_cfg_blocks = True + assert len(state_data_deps[0]) == 1 + read_memlet: Memlet = list(state_data_deps[0])[0] + assert read_memlet.data == 'A' - sdfg = myprog.to_sdfg() + assert len(state_data_deps[1]) == 2 + out_containers = [m.data for m in state_data_deps[1]] + assert 'B' in out_containers + assert 'tmp' in out_containers + assert 'A' not in out_containers + + +def test_state_data_dependence_with_non_contained_read_in_map(): + sdfg = SDFG('myprog') + N = dace.symbol('N') + sdfg.add_array('A', (N, ), dace.float32) + sdfg.add_array('tmp', (N, ), dace.float32) + sdfg.add_array('B', (N, ), dace.float32) + mystate = sdfg.add_state('mystate', is_start_block=True) + a_read = mystate.add_access('A') + tmp = mystate.add_access('tmp') + b_write = mystate.add_access('B') + m_entry, m_exit = mystate.add_map('my_map', {'i': '0:ceil(N/2)'}) + t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') + t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') + mystate.add_memlet_path(a_read, m_entry, t1, memlet=Memlet('A[i]'), dst_conn='i1') + mystate.add_memlet_path(t1, tmp, memlet=Memlet('tmp[i]'), src_conn='o1') + mystate.add_memlet_path(tmp, t2, memlet=Memlet('tmp[i+ceil(N/2)]'), dst_conn='i1') + mystate.add_memlet_path(t2, m_exit, b_write, memlet=Memlet('B[i]'), src_conn='o1') res = {} - pipeline = Pipeline([CFGDataDependence()]) - pipeline.__experimental_cfg_block_compatible__ = True - pipeline.apply_pass(sdfg, res) + Pipeline([StateDataDependence()]).apply_pass(sdfg, res) + state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] + + assert len(state_data_deps[0]) == 2 + in_containers = [m.data for m in state_data_deps[0]] + assert 'A' in in_containers + assert 'tmp' in in_containers + assert 'B' not in in_containers - print(sdfg) - ''' + assert len(state_data_deps[1]) == 2 + out_containers = [m.data for m in state_data_deps[1]] + assert 'B' in out_containers + assert 'tmp' in out_containers + assert 'A' not in out_containers if __name__ == '__main__': - test_simple_state_data_dependence_with_self_contained_read() - #test_nested_cf_region_data_dependence() + test_state_data_dependence_with_contained_read() + test_state_data_dependence_with_contained_read_in_map() + test_state_data_dependence_with_non_contained_read_in_map() From 05b1c28847af2a2f222ed36342fd8da0cbaefb32 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 18 Sep 2024 18:19:28 +0200 Subject: [PATCH 008/108] Add loop lifting capabilities --- dace/codegen/control_flow.py | 13 +- dace/sdfg/state.py | 21 ++- .../interstate/loop_detection.py | 53 ++++-- .../transformation/interstate/loop_lifting.py | 112 ++++++++++++ .../simplification/control_flow_raising.py | 22 +++ .../interstate/loop_lifting_test.py | 164 ++++++++++++++++++ 6 files changed, 358 insertions(+), 27 deletions(-) create mode 100644 dace/transformation/interstate/loop_lifting.py create mode 100644 dace/transformation/passes/simplification/control_flow_raising.py create mode 100644 tests/transformations/interstate/loop_lifting_test.py diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index ae9351fc43..d170d04e77 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -270,10 +270,17 @@ def as_cpp(self, codegen, symbols) -> str: for i, elem in enumerate(self.elements): expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. - if isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region): - cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph + if (isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region) or + (isinstance(elem, GeneralLoopScope) and elem.loop)): + if isinstance(elem, BasicCFBlock): + g_elem = elem.state + elif isinstance(elem, GeneralBlock): + g_elem = elem.region + else: + g_elem = elem.loop + cfg = g_elem.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg - out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region) + out_edges = cfg.out_edges(g_elem) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: # Skip gotos to immediate successors diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index e8a8161747..7fcdc34e3e 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2965,26 +2965,35 @@ class LoopRegion(ControlFlowRegion): def __init__(self, label: str, - condition_expr: Optional[str] = None, + condition_expr: Optional[Union[str, CodeBlock]] = None, loop_var: Optional[str] = None, - initialize_expr: Optional[str] = None, - update_expr: Optional[str] = None, + initialize_expr: Optional[Union[str, CodeBlock]] = None, + update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, sdfg: Optional['SDFG'] = None): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: - self.init_statement = CodeBlock(initialize_expr) + if isinstance(initialize_expr, CodeBlock): + self.init_statement = initialize_expr + else: + self.init_statement = CodeBlock(initialize_expr) else: self.init_statement = None if condition_expr: - self.loop_condition = CodeBlock(condition_expr) + if isinstance(condition_expr, CodeBlock): + self.loop_condition = condition_expr + else: + self.loop_condition = CodeBlock(condition_expr) else: self.loop_condition = CodeBlock('True') if update_expr is not None: - self.update_statement = CodeBlock(update_expr) + if isinstance(update_expr, CodeBlock): + self.update_statement = update_expr + else: + self.update_statement = CodeBlock(update_expr) else: self.update_statement = None diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 93c2f6ea1c..de3ed9c04b 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop detection transformation """ import sympy as sp @@ -77,19 +77,20 @@ def can_be_applied(self, sdfg: sd.SDFG, permissive: bool = False) -> bool: if expr_index == 0: - return self.detect_loop(graph, False) is not None + return self.detect_loop(graph, False, permissive) is not None elif expr_index == 1: - return self.detect_loop(graph, True) is not None + return self.detect_loop(graph, True, permissive) is not None elif expr_index == 2: - return self.detect_rotated_loop(graph, False) is not None + return self.detect_rotated_loop(graph, False, permissive) is not None elif expr_index == 3: - return self.detect_rotated_loop(graph, True) is not None + return self.detect_rotated_loop(graph, True, permissive) is not None elif expr_index == 4: - return self.detect_self_loop(graph) is not None + return self.detect_self_loop(graph, permissive) is not None raise ValueError(f'Invalid expression index {expr_index}') - def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -159,13 +160,19 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Option # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) - def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -234,13 +241,18 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) - def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: + def detect_self_loop(self, graph: ControlFlowRegion, accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -288,9 +300,14 @@ def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py new file mode 100644 index 0000000000..52e6e6e540 --- /dev/null +++ b/dace/transformation/interstate/loop_lifting.py @@ -0,0 +1,112 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from dace import properties +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.transformation import transformation +from dace.transformation.interstate.loop_detection import DetectLoop + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class LoopLifting(DetectLoop, transformation.MultiStateTransformation): + + def can_be_applied(self, graph: transformation.ControlFlowRegion, expr_index: int, sdfg: transformation.SDFG, + permissive: bool = False) -> bool: + # Check loop detection with permissive = True, which allows loops where no iteration variable could be detected. + # We want this to detect while loops. + if not super().can_be_applied(graph, expr_index, sdfg, permissive=True): + return False + + # Check that there's a condition edge, that's the only requirement to lift it into loop. + cond_edge = self.loop_condition_edge() + if not cond_edge or cond_edge.data.condition is None: + return False + return True + + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): + first_state = self.loop_guard if self.expr_index <= 1 else self.loop_begin + after = self.exit_state + + loop_info = self.loop_information() + + body = self.loop_body() + meta = self.loop_meta_states() + full_body = set(body) + full_body.update(meta) + cond_edge = self.loop_condition_edge() + incr_edge = self.loop_increment_edge() + inverted = cond_edge is incr_edge + init_edge = self.loop_init_edge() + exit_edge = self.loop_exit_edge() + + label = 'loop_' + first_state.label + if loop_info is None: + itvar = None + init_expr = None + incr_expr = None + else: + incr_expr = f'{loop_info[0]} = {incr_edge.data.assignments[loop_info[0]]}' + init_expr = f'{loop_info[0]} = {init_edge.data.assignments[loop_info[0]]}' + itvar = loop_info[0] + + left_over_assignments = {} + for k in init_edge.data.assignments.keys(): + if k != itvar: + left_over_assignments[k] = init_edge.data.assignments[k] + left_over_incr_assignments = {} + for k in incr_edge.data.assignments.keys(): + if k != itvar: + left_over_incr_assignments[k] = incr_edge.data.assignments[k] + + # TODO(later): In the case of inverted loops with non-loop-variable assignmentes AND the loop latch condition on + # the backedge, do not perform lifting for now. Note, the functionality in the lifting is there (see below, + # where left over increment assignments are used), but a bug in our control-flow-detection in codegen currently + # leads to wrong code being generated by this niche case. Remove the following check if the bug is fixed, and + # then these loops will also be lifted correctly. + if left_over_incr_assignments != {} and inverted: + return + + loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, + update_expr=incr_expr, inverted=inverted, sdfg=sdfg) + + graph.add_node(loop) + graph.add_edge(init_edge.src, loop, + InterstateEdge(condition=init_edge.data.condition, assignments=left_over_assignments)) + graph.add_edge(loop, after, InterstateEdge(assignments=exit_edge.data.assignments)) + + loop.add_node(first_state, is_start_block=True) + for n in full_body: + if n is not first_state: + loop.add_node(n) + added = set() + for e in graph.all_edges(*full_body): + if e.src in full_body and e.dst in full_body: + if not e in added: + added.add(e) + if e is incr_edge: + if left_over_incr_assignments != {}: + # If there are left over increments in an inverted loop, only execute them if the condition + # still holds. This is due to SDFG semantics, where interstate assignments are only executed + # if the condition on the edge holds (i.e., the edge is taken). This must be reflected in + # the raised loop. This is a very niche case - specifically, a do-while, where there is a + # non-loop-variable assignment AND the loop latch condition on the back-edge. + left_over_increment_cond = None + if inverted: + left_over_increment_cond = cond_edge.data.condition + + loop.add_edge(e.src, loop.add_state(label + '_tail'), + InterstateEdge(assignments=left_over_incr_assignments, + condition=left_over_increment_cond)) + elif e is cond_edge: + e.data.condition = properties.CodeBlock('1') + loop.add_edge(e.src, e.dst, e.data) + else: + loop.add_edge(e.src, e.dst, e.data) + + # Remove old loop. + for n in full_body: + graph.remove_node(n) + + sdfg.recheck_using_experimental_blocks() + sdfg.reset_cfg_list() diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py new file mode 100644 index 0000000000..5cad716176 --- /dev/null +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -0,0 +1,22 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from dace import properties +from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.interstate.loop_lifting import LoopLifting + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowRaising(ppl.Pass): + + CATEGORY: str = 'Simplification' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.CFG + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def apply_pass(self, top_sdfg: ppl.SDFG, _) -> ppl.Any | None: + for sdfg in top_sdfg.all_sdfgs_recursive(): + sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py new file mode 100644 index 0000000000..e7a026a812 --- /dev/null +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -0,0 +1,164 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests loop raising trainsformations. """ + +import numpy as np +import dace +from dace.memlet import Memlet +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import LoopRegion +from dace.transformation.interstate.loop_lifting import LoopLifting + + +def test_lift_regular_for_loop(): + sdfg = SDFG('regular_for') + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + start_state = sdfg.add_state('start', is_start_block=True) + init_state = sdfg.add_state('start', is_start_block=True) + guard_state = sdfg.add_state('guard') + main_state = sdfg.add_state('loop_state') + loop_exit = sdfg.add_state('exit') + final_state = sdfg.add_state('final') + sdfg.add_edge(start_state, init_state, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(init_state, guard_state, InterstateEdge(assignments={'i': 0, 'k': 0})) + sdfg.add_edge(guard_state, main_state, InterstateEdge(condition='i < N')) + sdfg.add_edge(main_state, guard_state, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(guard_state, loop_exit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loop_exit, final_state, InterstateEdge()) + a_access = main_state.add_access('A') + w_tasklet = main_state.add_tasklet('t1', {}, {'out'}, 'out = 1') + main_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loop_exit.add_access('A') + w_tasklet_2 = loop_exit.add_tasklet('t1', {}, {'out'}, 'out = k') + loop_exit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = final_state.add_access('A') + w_tasklet_3 = final_state.add_tasklet('t1', {}, {'out'}, 'out = j') + final_state.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +def test_lift_loop_llvm_canonical(): + sdfg = dace.SDFG('llvm_canonical') + N = dace.symbol('N') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state('guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) + sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) + sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) + sdfg.add_edge(body, latch, InterstateEdge()) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loopexit, exitstate, InterstateEdge()) + + a_access = body.add_access('A') + w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1') + body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + a_access_2 = loopexit.add_access('A') + w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k') + loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = exitstate.add_access('A') + w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j') + exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +def test_lift_loop_llvm_canonical_while(): + sdfg = dace.SDFG('llvm_canonical_while') + N = dace.symbol('N') + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('A', (N,), dace.int32) + sdfg.add_scalar('i', dace.int32, transient=True) + + entry = sdfg.add_state('entry', is_start_block=True) + guard = sdfg.add_state('guard') + preheader = sdfg.add_state('preheader') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + loopexit = sdfg.add_state('loopexit') + exitstate = sdfg.add_state('exitstate') + + sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) + sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) + sdfg.add_edge(preheader, body, InterstateEdge(assignments={'k': 0})) + sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(loopexit, exitstate, InterstateEdge()) + + i_init_write = entry.add_access('i') + iw_init_tasklet = entry.add_tasklet('ti', {}, {'out'}, 'out = 0') + entry.add_edge(iw_init_tasklet, 'out', i_init_write, None, Memlet('i[0]')) + a_access = body.add_access('A') + w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1') + body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]')) + i_read = body.add_access('i') + i_write = body.add_access('i') + iw_tasklet = body.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2') + body.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]')) + body.add_edge(iw_tasklet, 'out', i_write, None, Memlet('i[0]')) + a_access_2 = loopexit.add_access('A') + w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k') + loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]')) + a_access_3 = exitstate.add_access('A') + w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j') + exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]')) + + N = 30 + A = np.zeros((N,)).astype(np.int32) + A_valid = np.zeros((N,)).astype(np.int32) + sdfg(A=A_valid, N=N) + sdfg.apply_transformations_repeated([LoopLifting]) + + assert sdfg.using_experimental_blocks == True + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + sdfg(A=A, N=N) + + assert np.allclose(A_valid, A) + + +if __name__ == '__main__': + test_lift_regular_for_loop() + test_lift_loop_llvm_canonical() + test_lift_loop_llvm_canonical_while() From f08d95e858ca093eef4e34ce257082691b08b587 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 19 Sep 2024 12:44:22 +0200 Subject: [PATCH 009/108] Adjust loop detection to LLVM canonical semantics --- .../analysis/writeset_underapproximation.py | 21 +++--- dace/sdfg/propagation.py | 15 ++-- .../interstate/loop_detection.py | 71 ++++++++++++------- .../transformation/interstate/loop_lifting.py | 22 +----- .../interstate/loop_lifting_test.py | 4 +- tests/transformations/loop_detection_test.py | 8 +-- 6 files changed, 71 insertions(+), 70 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index e1b88f9401..0d2fd989a3 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -153,19 +153,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].expr) - elif isinstance(expr[i], tuple): - dexprs.append(( - expr[i][0].expr if isinstance( - expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].expr if isinstance( - expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].expr if isinstance( - expr[i][2], symbolic.SymExpr) else expr[i][2], - expr.tile_sizes[i])) + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.expr) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].expr if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].expr if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].expr if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], + expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, node_range) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 6447d8f89b..1de7ce3977 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -94,15 +94,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].approx) - elif isinstance(expr[i], tuple): - dexprs.append((expr[i][0].approx if isinstance(expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].approx if isinstance(expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].approx if isinstance(expr[i][2], symbolic.SymExpr) else expr[i][2], + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.approx) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].approx if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].approx if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].approx if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, overapprox_range) diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index de3ed9c04b..95056eb344 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop detection transformation """ +from re import I import sympy as sp import networkx as nx from typing import AnyStr, Optional, Tuple, List, Set @@ -199,14 +200,10 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if len(latch_outedges) != 2: return None - # All incoming edges to the start of the loop must set the same variable - itvar = None - for iedge in begin_inedges: - if itvar is None: - itvar = set(iedge.data.assignments.keys()) - else: - itvar &= iedge.data.assignments.keys() - if itvar is None: + # A for-loop latch can further only have one incoming edge (the increment edge). A while-loop, i.e., a loop + # with no explicit iteration variable, may have more than that. + latch_inedges = graph.in_edges(latch) + if not accept_missing_itvar and len(latch_inedges) != 1: return None # Outgoing edges must be a negation of each other @@ -238,8 +235,22 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if backedge is None: return None - # The backedge must reassign the iteration variable - itvar &= backedge.data.assignments.keys() + # The iteration variable must be reassigned on all incoming edges to the latch block. + # If an assignment overlap of exactly one variable is found between the initialization edge and the edges + # going into the latch block, that will be the iteration variable. + itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() + itvar_edge_set.update(begin_inedges) + itvar_edge_set.update(latch_inedges) + itvar = None + for iedge in itvar_edge_set: + if iedge is backedge: + continue + if itvar is None: + itvar = set(iedge.data.assignments.keys()) + else: + itvar &= iedge.data.assignments.keys() + if itvar is None: + return None if len(itvar) != 1: if not accept_missing_itvar: # Either no consistent iteration variable found, or too many consistent iteration variables found @@ -430,7 +441,7 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: return next(e for e in graph.in_edges(guard) if e.src in body) elif self.expr_index in (2, 3): body = self.loop_body() - return next(e for e in graph.in_edges(begin) if e.src in body) + return graph.in_edges(self.loop_latch)[0] elif self.expr_index == 4: return graph.edges_between(begin, begin)[0] @@ -554,8 +565,7 @@ def find_rotated_for_loop( """ Finds rotated loop range from state machine. - :param latch: State from which the outgoing edges detect whether to exit - the loop or not. + :param latch: State from which the outgoing edges detect whether to exit the loop or not. :param entry: First state in the loop body. :param itervar: An optional field that overrides the analyzed iteration variable. :return: (iteration variable, (start, end, stride), @@ -565,11 +575,20 @@ def find_rotated_for_loop( # Extract state transition edge information entry_inedges = graph.in_edges(entry) condition_edge = graph.edges_between(latch, entry)[0] + latch_inedges = graph.in_edges(latch) - # All incoming edges to the loop entry must set the same variable + self_loop = latch is entry if itervar is None: + # The iteration variable must be reassigned on all incoming edges to the latch block. + # If an assignment overlap of exactly one variable is found between the initialization edge and the edges + # going into the latch block, that will be the iteration variable. + itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() + itvar_edge_set.update(entry_inedges) + itvar_edge_set.update(latch_inedges) itervars = None - for iedge in entry_inedges: + for iedge in itvar_edge_set: + if iedge is condition_edge and not self_loop: + continue if itervars is None: itervars = set(iedge.data.assignments.keys()) else: @@ -587,18 +606,12 @@ def find_rotated_for_loop( # have one assignment. init_edges = [] init_assignment = None - step_edge = None itersym = symbolic.symbol(itervar) for iedge in entry_inedges: + if iedge is condition_edge: + continue assignment = iedge.data.assignments[itervar] - if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols: - if step_edge is None: - step_edge = iedge - else: - # More than one edge with the iteration variable as a free - # symbol, which is not legal. Invalid for loop. - return None - else: + if itersym not in symbolic.pystr_to_symbolic(assignment).free_symbols: if init_assignment is None: init_assignment = assignment init_edges.append(iedge) @@ -608,10 +621,18 @@ def find_rotated_for_loop( return None else: init_edges.append(iedge) - if step_edge is None or len(init_edges) == 0 or init_assignment is None: + if len(init_edges) == 0 or init_assignment is None: # Less than two assignment variations, can't be a valid for loop. return None + if self_loop: + step_edge = condition_edge + else: + step_edge = None if len(latch_inedges) != 1 else latch_inedges[0] + if step_edge is None: + # No explicit step edge found. + return None + # Get the init expression and the stride. start = symbolic.pystr_to_symbolic(init_assignment) stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar]) - itersym) diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 52e6e6e540..54363dd8e2 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -36,7 +36,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): full_body.update(meta) cond_edge = self.loop_condition_edge() incr_edge = self.loop_increment_edge() - inverted = cond_edge is incr_edge + inverted = self.expr_index in (2, 3) init_edge = self.loop_init_edge() exit_edge = self.loop_exit_edge() @@ -59,14 +59,6 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): if k != itvar: left_over_incr_assignments[k] = incr_edge.data.assignments[k] - # TODO(later): In the case of inverted loops with non-loop-variable assignmentes AND the loop latch condition on - # the backedge, do not perform lifting for now. Note, the functionality in the lifting is there (see below, - # where left over increment assignments are used), but a bug in our control-flow-detection in codegen currently - # leads to wrong code being generated by this niche case. Remove the following check if the bug is fixed, and - # then these loops will also be lifted correctly. - if left_over_incr_assignments != {} and inverted: - return - loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, update_expr=incr_expr, inverted=inverted, sdfg=sdfg) @@ -86,18 +78,8 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): added.add(e) if e is incr_edge: if left_over_incr_assignments != {}: - # If there are left over increments in an inverted loop, only execute them if the condition - # still holds. This is due to SDFG semantics, where interstate assignments are only executed - # if the condition on the edge holds (i.e., the edge is taken). This must be reflected in - # the raised loop. This is a very niche case - specifically, a do-while, where there is a - # non-loop-variable assignment AND the loop latch condition on the back-edge. - left_over_increment_cond = None - if inverted: - left_over_increment_cond = cond_edge.data.condition - loop.add_edge(e.src, loop.add_state(label + '_tail'), - InterstateEdge(assignments=left_over_incr_assignments, - condition=left_over_increment_cond)) + InterstateEdge(assignments=left_over_incr_assignments)) elif e is cond_edge: e.data.condition = properties.CodeBlock('1') loop.add_edge(e.src, e.dst, e.data) diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index e7a026a812..e3098d4e5c 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -72,8 +72,8 @@ def test_lift_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) - sdfg.add_edge(body, latch, InterstateEdge()) - sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) sdfg.add_edge(loopexit, exitstate, InterstateEdge()) diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 5469f45762..891d520f41 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -37,8 +37,8 @@ def test_loop_rotated(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() @@ -106,8 +106,8 @@ def test_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) From 1d903463b9277b0d1eaad713d26e682b7964f1a6 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 19 Sep 2024 13:17:20 +0200 Subject: [PATCH 010/108] Test fix --- tests/transformations/loop_to_map_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 2cab97da78..12d4898858 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -741,8 +741,8 @@ def test_rotated_loop_to_map(simplify): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) From 6b5ef0ce115ce33a34c489bc42c479cb7d9f5f6a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 19 Sep 2024 17:40:25 +0200 Subject: [PATCH 011/108] Remove unnecessary imports --- .../analysis/writeset_underapproximation.py | 24 +++++++++---------- dace/sdfg/propagation.py | 16 ++++++------- dace/transformation/subgraph/expansion.py | 9 ++----- dace/transformation/subgraph/helpers.py | 17 ++++--------- .../writeset_underapproximation_test.py | 1 + 5 files changed, 26 insertions(+), 41 deletions(-) diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index 0d2fd989a3..557ee8a73b 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -82,27 +82,26 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for dim in range(data_dims): dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].expr) - elif isinstance(expr[dim], tuple): - dexprs.append( - (expr[dim][0].expr if isinstance(expr[dim][0], symbolic.SymExpr) else - expr[dim][0], expr[dim][1].expr if isinstance( - expr[dim][1], symbolic.SymExpr) else expr[dim][1], expr[dim][2].expr - if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.expr) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].expr if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].expr if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].expr if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableUnderapproximationMemletPattern.extensions().keys(): smpattern = pattern_class() - if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, - data_dims): + if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, data_dims): self.patterns_per_dim[dim] = smpattern break return None not in self.patterns_per_dim def _iteration_variables_appear_multiple_times(self, data_dims, expressions, other_params, params): + # TODO: This name implies exactly the inverse of the returned value.. for expr in expressions: for param in params: occured_before = False @@ -139,8 +138,7 @@ def _iteration_variables_appear_multiple_times(self, data_dims, expressions, oth def _make_range(self, node_range): return subsets.Range([(rb.expr if isinstance(rb, symbolic.SymExpr) else rb, - re.expr if isinstance( - re, symbolic.SymExpr) else re, + re.expr if isinstance(re, symbolic.SymExpr) else re, rs.expr if isinstance(rs, symbolic.SymExpr) else rs) for rb, re, rs in node_range]) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1de7ce3977..f62bb6eb58 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -62,17 +62,17 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for rb, re, rs in node_range]) for dim in range(data_dims): - dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].approx) - elif isinstance(expr[dim], tuple): - dexprs.append((expr[dim][0].approx if isinstance(expr[dim][0], symbolic.SymExpr) else expr[dim][0], - expr[dim][1].approx if isinstance(expr[dim][1], symbolic.SymExpr) else expr[dim][1], - expr[dim][2].approx if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.approx) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].approx if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].approx if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].approx if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableMemletPattern.extensions().keys(): smpattern = pattern_class() diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index db1e9b59ab..aa182e8c80 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -1,26 +1,21 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement the expansion transformation. """ -from dace import dtypes, registry, symbolic, subsets +from dace import dtypes, symbolic, subsets from dace.sdfg import nodes -from dace.memlet import Memlet from dace.sdfg import replace, SDFG, dynamic_map_inputs from dace.sdfg.graph import SubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg from dace.transformation.subgraph import helpers from collections import defaultdict from copy import deepcopy as dcpy -from typing import List, Union import itertools -import dace.libraries.standard as stdlib import warnings -import sys def offset_map(state, map_entry): diff --git a/dace/transformation/subgraph/helpers.py b/dace/transformation/subgraph/helpers.py index b2af49c879..0ea1903522 100644 --- a/dace/transformation/subgraph/helpers.py +++ b/dace/transformation/subgraph/helpers.py @@ -1,20 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Subgraph Transformation Helper API """ -from dace import dtypes, registry, symbolic, subsets -from dace.sdfg import nodes, utils -from dace.memlet import Memlet -from dace.sdfg import replace, SDFG, SDFGState -from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg +from dace import subsets +from dace.sdfg import nodes from dace.sdfg.graph import SubgraphView -from collections import defaultdict import copy -from typing import List, Union, Dict, Tuple, Set - -import dace.libraries.standard as stdlib - -import itertools +from typing import List, Dict, Set # **************** # Helper functions diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index d0c0e03209..d27683b801 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -545,6 +545,7 @@ def test_nested_sdfg_in_map_branches(): Nested SDFG that overwrites second dimension of array conditionally. --> should approximate write-set of map as empty """ + # No, should be approximated precisely - at least certainly with CF regions..? @dace.program def nested_loop(A: dace.float64[M, N]): From 23af03863c9f0a4196b6978866cf220f5491f4ec Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 12:26:12 +0200 Subject: [PATCH 012/108] Improved loop detection --- .../interstate/loop_detection.py | 242 +++++++++++++----- .../simplification/control_flow_raising.py | 2 +- tests/transformations/loop_detection_test.py | 51 ++-- 3 files changed, 206 insertions(+), 89 deletions(-) diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 95056eb344..daf13599fe 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,10 +1,9 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop detection transformation """ -from re import I import sympy as sp import networkx as nx -from typing import AnyStr, Optional, Tuple, List, Set +from typing import AnyStr, Iterable, Optional, Tuple, List, Set from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge @@ -30,6 +29,9 @@ class DetectLoop(transformation.PatternTransformation): # Available for rotated and self loops entry_state = transformation.PatternNode(sd.SDFGState) + # Available for explicit-latch rotated loops + loop_break = transformation.PatternNode(sd.SDFGState) + @classmethod def expressions(cls): # Case 1: Loop with one state @@ -70,7 +72,32 @@ def expressions(cls): ssdfg.add_edge(cls.loop_begin, cls.loop_begin, sd.InterstateEdge()) ssdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) - return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg] + # Case 6: Rotated multi-state loop with explicit exiting and latch states + mlrmsdfg = gr.OrderedDiGraph() + mlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + # Case 7: Rotated single-state loop with explicit exiting and latch states + mlrsdfg = gr.OrderedDiGraph() + mlrsdfg.add_nodes_from([cls.entry_state, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.loop_latch, sd.InterstateEdge()) + + # Case 8: Guarded rotated multi-state loop with explicit exiting and latch states (modification of case 6) + gmlrmsdfg = gr.OrderedDiGraph() + gmlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + gmlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_begin, cls.loop_break, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg, mlrmsdfg, mlrsdfg, gmlrmsdfg] def can_be_applied(self, graph: ControlFlowRegion, @@ -78,15 +105,21 @@ def can_be_applied(self, sdfg: sd.SDFG, permissive: bool = False) -> bool: if expr_index == 0: - return self.detect_loop(graph, False, permissive) is not None + return self.detect_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 1: - return self.detect_loop(graph, True, permissive) is not None + return self.detect_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 2: - return self.detect_rotated_loop(graph, False, permissive) is not None + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 3: - return self.detect_rotated_loop(graph, True, permissive) is not None + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 4: - return self.detect_self_loop(graph, permissive) is not None + return self.detect_self_loop(graph, accept_missing_itvar=permissive) is not None + elif expr_index in (5, 7): + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive, + separate_latch=True) is not None + elif expr_index == 6: + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive, + separate_latch=True) is not None raise ValueError(f'Invalid expression index {expr_index}') @@ -173,7 +206,7 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, return next(iter(itvar)) def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, - accept_missing_itvar: bool = False) -> Optional[str]: + accept_missing_itvar: bool = False, separate_latch: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -189,6 +222,9 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, :return: The loop variable or ``None`` if not detected. """ latch = self.loop_latch + ltest = self.loop_latch + if separate_latch: + ltest = self.loop_break if multistate_loop else self.loop_begin begin = self.loop_begin # A for-loop start has at least two incoming edges (init and increment) @@ -196,7 +232,7 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if len(begin_inedges) < 2: return None # A for-loop latch only has two outgoing edges (loop condition and exit-loop) - latch_outedges = graph.out_edges(latch) + latch_outedges = graph.out_edges(ltest) if len(latch_outedges) != 2: return None @@ -212,8 +248,13 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, # All nodes inside loop must be dominated by loop start dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) - loop_nodes += [latch] + if begin is ltest: + loop_nodes = [begin] + else: + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes.append(latch) + if ltest is not latch and ltest is not begin: + loop_nodes.append(ltest) backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -235,33 +276,7 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if backedge is None: return None - # The iteration variable must be reassigned on all incoming edges to the latch block. - # If an assignment overlap of exactly one variable is found between the initialization edge and the edges - # going into the latch block, that will be the iteration variable. - itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() - itvar_edge_set.update(begin_inedges) - itvar_edge_set.update(latch_inedges) - itvar = None - for iedge in itvar_edge_set: - if iedge is backedge: - continue - if itvar is None: - itvar = set(iedge.data.assignments.keys()) - else: - itvar &= iedge.data.assignments.keys() - if itvar is None: - return None - if len(itvar) != 1: - if not accept_missing_itvar: - # Either no consistent iteration variable found, or too many consistent iteration variables found - return None - else: - if len(itvar) == 0: - return '' - else: - return None - - return next(iter(itvar)) + return rotated_loop_find_itvar(begin_inedges, latch_inedges, backedge, ltest, accept_missing_itvar)[0] def detect_self_loop(self, graph: ControlFlowRegion, accept_missing_itvar: bool = False) -> Optional[str]: """ @@ -338,9 +353,10 @@ def loop_information( if self.expr_index <= 1: guard = self.loop_guard return find_for_loop(guard.parent_graph, guard, entry, itervar) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch - return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar) + return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar, + separate_latch=(self.expr_index in (5, 6, 7))) elif self.expr_index == 4: return find_rotated_for_loop(entry.parent_graph, entry, entry, itervar) @@ -362,6 +378,14 @@ def loop_body(self) -> List[ControlFlowBlock]: return loop_nodes elif self.expr_index == 4: return [begin] + elif self.expr_index in (5, 7): + ltest = self.loop_break + latch = self.loop_latch + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes += [ltest, latch] + return loop_nodes + elif self.expr_index == 6: + return [begin, self.loop_latch] return [] @@ -371,8 +395,10 @@ def loop_meta_states(self) -> List[ControlFlowBlock]: """ if self.expr_index in (0, 1): return [self.loop_guard] - if self.expr_index in (2, 3): + if self.expr_index in (2, 3, 6): return [self.loop_latch] + if self.expr_index in (5, 7): + return [self.loop_break, self.loop_latch] return [] def loop_init_edge(self) -> gr.Edge[InterstateEdge]: @@ -385,7 +411,7 @@ def loop_init_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src not in body) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch return next(e for e in graph.in_edges(begin) if e.src is not latch) elif self.expr_index == 4: @@ -405,9 +431,12 @@ def loop_exit_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index in (2, 3): latch = self.loop_latch return graph.edges_between(latch, exitstate)[0] - elif self.expr_index == 4: + elif self.expr_index in (4, 6): begin = self.loop_begin return graph.edges_between(begin, exitstate)[0] + elif self.expr_index in (5, 7): + ltest = self.loop_break + return graph.edges_between(ltest, exitstate)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -426,6 +455,10 @@ def loop_condition_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index == 4: begin = self.loop_begin return graph.edges_between(begin, begin)[0] + elif self.expr_index in (5, 6, 7): + latch = self.loop_latch + ltest = self.loop_break if self.expr_index in (5, 7) else self.loop_begin + return graph.edges_between(ltest, latch)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -439,7 +472,7 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src in body) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): body = self.loop_body() return graph.in_edges(self.loop_latch)[0] elif self.expr_index == 4: @@ -448,6 +481,84 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: raise ValueError(f'Invalid expression index {self.expr_index}') +def rotated_loop_find_itvar(begin_inedges: List[gr.Edge[InterstateEdge]], + latch_inedges: List[gr.Edge[InterstateEdge]], + backedge: gr.Edge[InterstateEdge], latch: ControlFlowBlock, + accept_missing_itvar: bool = False) -> Tuple[Optional[str], + Optional[gr.Edge[InterstateEdge]]]: + # The iteration variable must be assigned (initialized) on all edges leading into the beginning block, which + # are not the backedge. Gather all variabes for which that holds - they are all candidates for the iteration + # variable (Phase 1). Said iteration variable must then be incremented: + # EITHER: On the backedge, in which case the increment is only executed if the loop does not exit. This + # corresponds to a while(true) loop that checks the condition at the end of the loop body and breaks + # if it does not hold before incrementing. (Scenario 1) + # OR: On the edge(s) leading into the latch, in which case the increment is executed BEFORE the condition is + # checked - which corresponds to a do-while loop. (Scenario 2) + # For either case, the iteration variable may only be incremented on one of these places. Filter the candidates + # down to each variable for which this condition holds (Phase 2). If there is exactly one candidate remaining, + # that is the iteration variable. Otherwise it cannot be determined. + + # Phase 1: Gather iteration variable candidates. + itvar_candidates = None + for e in begin_inedges: + if e is backedge: + continue + if itvar_candidates is None: + itvar_candidates = set(e.data.assignments.keys()) + else: + itvar_candidates &= set(e.data.assignments.keys()) + + # Phase 2: Filter down the candidates according to incrementation edges. + step_edge = None + filtered_candidates = set() + backedge_incremented = set(backedge.data.assignments.keys()) + latch_incremented = None + if backedge.src is not backedge.dst: + # If this is a self loop, there are no edges going into the latch to be considered. The only incoming edges are + # from outside the loop. + for e in latch_inedges: + if e is backedge: + continue + if latch_incremented is None: + latch_incremented = set(e.data.assignments.keys()) + else: + latch_incremented &= set(e.data.assignments.keys()) + if latch_incremented is None: + latch_incremented = set() + for cand in itvar_candidates: + if cand in backedge_incremented: + # Scenario 1. + + # TODO: Not sure if the condition below is a necessary prerequisite. + # Note, only allow this scenario if the backedge leads directly from the latch to the entry, i.e., there is + # no intermediate block on the backedge path. + if backedge.src is not latch: + continue + + if cand not in latch_incremented: + filtered_candidates.add(cand) + elif cand in latch_incremented: + # Scenario 2. + if cand not in backedge_incremented: + filtered_candidates.add(cand) + if len(filtered_candidates) != 1: + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None, None + else: + if len(filtered_candidates) == 0: + return '', None + else: + return None, None + else: + itvar = next(iter(filtered_candidates)) + if itvar in backedge_incremented: + step_edge = backedge + elif len(latch_inedges) == 1: + step_edge = latch_inedges[0] + return itvar, step_edge + + def find_for_loop( graph: ControlFlowRegion, guard: sd.SDFGState, @@ -548,6 +659,10 @@ def find_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None @@ -559,7 +674,8 @@ def find_rotated_for_loop( graph: ControlFlowRegion, latch: sd.SDFGState, entry: sd.SDFGState, - itervar: Optional[str] = None + itervar: Optional[str] = None, + separate_latch: bool = False, ) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType], Tuple[ List[sd.SDFGState], sd.SDFGState]]]: """ @@ -574,29 +690,19 @@ def find_rotated_for_loop( """ # Extract state transition edge information entry_inedges = graph.in_edges(entry) - condition_edge = graph.edges_between(latch, entry)[0] + if separate_latch: + condition_edge = graph.in_edges(latch)[0] + backedge = graph.edges_between(latch, entry)[0] + else: + condition_edge = graph.edges_between(latch, entry)[0] + backedge = condition_edge latch_inedges = graph.in_edges(latch) self_loop = latch is entry + step_edge = None if itervar is None: - # The iteration variable must be reassigned on all incoming edges to the latch block. - # If an assignment overlap of exactly one variable is found between the initialization edge and the edges - # going into the latch block, that will be the iteration variable. - itvar_edge_set: Set[gr.Edge[InterstateEdge]] = set() - itvar_edge_set.update(entry_inedges) - itvar_edge_set.update(latch_inedges) - itervars = None - for iedge in itvar_edge_set: - if iedge is condition_edge and not self_loop: - continue - if itervars is None: - itervars = set(iedge.data.assignments.keys()) - else: - itervars &= iedge.data.assignments.keys() - if itervars and len(itervars) == 1: - itervar = next(iter(itervars)) - else: - # Ambiguous or no iteration variable + itervar, step_edge = rotated_loop_find_itvar(entry_inedges, latch_inedges, backedge, latch) + if itervar is None: return None condition = condition_edge.data.condition_sympy() @@ -628,9 +734,7 @@ def find_rotated_for_loop( if self_loop: step_edge = condition_edge else: - step_edge = None if len(latch_inedges) != 1 else latch_inedges[0] if step_edge is None: - # No explicit step edge found. return None # Get the init expression and the stride. @@ -664,6 +768,10 @@ def find_rotated_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 5cad716176..2f92ab4e86 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -7,7 +7,7 @@ @properties.make_properties @transformation.experimental_cfg_block_compatible -class ControlFlowRaising(ppl.Pass): +class ControlFlowLifting(ppl.Pass): CATEGORY: str = 'Simplification' diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 891d520f41..323a27787a 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -27,7 +27,8 @@ def tester(a: dace.float64[20]): assert rng == (1, 19, 1) -def test_loop_rotated(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_loop_rotated(increment_before_condition): sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -37,8 +38,12 @@ def test_loop_rotated(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + if increment_before_condition: + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + else: + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2'))) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() @@ -48,8 +53,9 @@ def test_loop_rotated(): assert rng == (0, dace.symbol('N') - 1, 2) -@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') def test_loop_rotated_extra_increment(): + # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not + # be allowed and consequently not be detected as loops. sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -60,15 +66,13 @@ def test_loop_rotated_extra_increment(): exitstate = sdfg.add_state('exitstate') sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0))) + sdfg.add_edge(body, latch, dace.InterstateEdge()) sdfg.add_edge(latch, increment, dace.InterstateEdge('i < N')) sdfg.add_edge(increment, body, dace.InterstateEdge(assignments=dict(i='i + 1'))) sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N')) xform = CountLoops() - assert sdfg.apply_transformations(xform) == 1 - itvar, rng, _ = xform.loop_information() - assert itvar == 'i' - assert rng == (0, dace.symbol('N') - 1, 1) + assert sdfg.apply_transformations(xform) == 0 def test_self_loop(): @@ -91,7 +95,8 @@ def test_self_loop(): assert rng == (2, dace.symbol('N') - 1, 3) -def test_loop_llvm_canonical(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_loop_llvm_canonical(increment_before_condition): sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -106,8 +111,12 @@ def test_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + if increment_before_condition: + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) + else: + sdfg.add_edge(body, latch, dace.InterstateEdge()) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) @@ -118,9 +127,10 @@ def test_loop_llvm_canonical(): assert rng == (0, dace.symbol('N') - 1, 1) -@pytest.mark.skip('Extra incrementation states should not be supported by loop detection') @pytest.mark.parametrize('with_bounds_check', (False, True)) def test_loop_llvm_canonical_with_extras(with_bounds_check): + # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not + # be allowed and consequently not be detected as loops. sdfg = dace.SDFG('tester') sdfg.add_symbol('N', dace.int32) @@ -148,17 +158,16 @@ def test_loop_llvm_canonical_with_extras(with_bounds_check): sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) xform = CountLoops() - assert sdfg.apply_transformations(xform) == 1 - itvar, rng, _ = xform.loop_information() - assert itvar == 'i' - assert rng == (0, dace.symbol('N') - 1, 1) + assert sdfg.apply_transformations(xform) == 0 if __name__ == '__main__': test_pyloop() - test_loop_rotated() - # test_loop_rotated_extra_increment() + test_loop_rotated(True) + test_loop_rotated(False) + test_loop_rotated_extra_increment() test_self_loop() - test_loop_llvm_canonical() - # test_loop_llvm_canonical_with_extras(False) - # test_loop_llvm_canonical_with_extras(True) + test_loop_llvm_canonical(True) + test_loop_llvm_canonical(False) + test_loop_llvm_canonical_with_extras(False) + test_loop_llvm_canonical_with_extras(True) From 3fbe26bbf5bf629b21c7b8f8b0616856818958f7 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 15:32:54 +0200 Subject: [PATCH 013/108] Loop detection and lifting fixes --- dace/codegen/control_flow.py | 29 ++++++++++--------- dace/codegen/targets/framecode.py | 13 ++++++++- dace/sdfg/state.py | 10 ++++++- .../interstate/loop_detection.py | 5 ++-- .../transformation/interstate/loop_lifting.py | 17 +++++++---- .../interstate/loop_lifting_test.py | 15 +++++++--- 6 files changed, 62 insertions(+), 27 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index cfa5c8d41d..d0cd3da8b4 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -539,26 +539,27 @@ def as_cpp(self, codegen, symbols) -> str: expr = '' if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable: - # Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined. - defined_vars = codegen.dispatcher.defined_vars - if not defined_vars.has(self.loop.loop_variable): - try: - init = f'{symbols[self.loop.loop_variable]} ' - except KeyError: - init = 'auto ' - symbols[self.loop.loop_variable] = None - init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) init = init.strip(';') update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols) update = update.strip(';') if self.loop.inverted: - expr += f'{init};\n' - expr += 'do {\n' - expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) - expr += f'{update};\n' - expr += f'\n}} while({cond});\n' + if self.loop.update_before_condition: + expr += f'{init};\n' + expr += 'do {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'{update};\n' + expr += f'}} while({cond});\n' + else: + expr += f'{init};\n' + expr += 'while (1) {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'if (!({cond}))\n' + expr += 'break;\n' + expr += f'{update};\n' + expr += '}\n' else: expr += f'for ({init}; {cond}; {update}) {{\n' expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 488c1c7fbd..2d3c524771 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,4 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +import ast import collections import copy import re @@ -15,11 +16,12 @@ from dace.codegen.prettycode import CodeIOStream from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.transformation.passes.analysis import StateReachability @@ -916,6 +918,15 @@ def generate_code(self, interstate_symbols.update(symbols) global_symbols.update(symbols) + if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: + init_assignment = cfr.init_statement.code[0] + if isinstance(init_assignment, ast.Assign): + init_assignment = init_assignment.value + if not cfr.loop_variable in interstate_symbols: + interstate_symbols[cfr.loop_variable] = infer_expr_type(ast.unparse(init_assignment)) + if not cfr.loop_variable in global_symbols: + global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] + for isvarName, isvarType in interstate_symbols.items(): if isvarType is None: raise TypeError(f'Type inference failed for symbol {isvarName}') diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 22ac601da1..4dc93a8d9d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2987,6 +2987,12 @@ class LoopRegion(ControlFlowRegion): inverted = Property(dtype=bool, default=False, desc='If True, the loop condition is checked after the first iteration.') + update_before_condition = Property(dtype=bool, + default=True, + desc='If False, the loop condition is checked before the update statement is' + + ' executed. This only applies to inverted loops, turning them from a typical ' + + 'do-while style into a while(true) with a break before the update (at the end ' + + 'of an iteration)if the condition no longer holds.') loop_variable = Property(dtype=str, default='', desc='The loop variable, if given') def __init__(self, @@ -2996,7 +3002,8 @@ def __init__(self, initialize_expr: Optional[Union[str, CodeBlock]] = None, update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, - sdfg: Optional['SDFG'] = None): + sdfg: Optional['SDFG'] = None, + update_before_condition = True): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: @@ -3025,6 +3032,7 @@ def __init__(self, self.loop_variable = loop_var or '' self.inverted = inverted + self.update_before_condition = update_before_condition def inline(self) -> Tuple[bool, Any]: """ diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index daf13599fe..bd65cec290 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -473,8 +473,9 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src in body) elif self.expr_index in (2, 3, 5, 6, 7): - body = self.loop_body() - return graph.in_edges(self.loop_latch)[0] + _, step_edge = rotated_loop_find_itvar(graph.in_edges(begin), graph.in_edges(self.loop_latch), + graph.edges_between(self.loop_latch, begin)[0], self.loop_latch) + return step_edge elif self.expr_index == 4: return graph.edges_between(begin, begin)[0] diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 54363dd8e2..604aa74d16 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -36,7 +36,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): full_body.update(meta) cond_edge = self.loop_condition_edge() incr_edge = self.loop_increment_edge() - inverted = self.expr_index in (2, 3) + inverted = self.expr_index in (2, 3, 5, 6, 7) init_edge = self.loop_init_edge() exit_edge = self.loop_exit_edge() @@ -55,12 +55,19 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): if k != itvar: left_over_assignments[k] = init_edge.data.assignments[k] left_over_incr_assignments = {} - for k in incr_edge.data.assignments.keys(): - if k != itvar: - left_over_incr_assignments[k] = incr_edge.data.assignments[k] + if incr_edge is not None: + for k in incr_edge.data.assignments.keys(): + if k != itvar: + left_over_incr_assignments[k] = incr_edge.data.assignments[k] + + if inverted and incr_edge is cond_edge: + update_before_condition = False + else: + update_before_condition = True loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, - update_expr=incr_expr, inverted=inverted, sdfg=sdfg) + update_expr=incr_expr, inverted=inverted, sdfg=sdfg, + update_before_condition=update_before_condition) graph.add_node(loop) graph.add_edge(init_edge.src, loop, diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index e3098d4e5c..843209794f 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -2,6 +2,7 @@ """ Tests loop raising trainsformations. """ import numpy as np +import pytest import dace from dace.memlet import Memlet from dace.sdfg.sdfg import SDFG, InterstateEdge @@ -52,7 +53,8 @@ def test_lift_regular_for_loop(): assert np.allclose(A_valid, A) -def test_lift_loop_llvm_canonical(): +@pytest.mark.parametrize('increment_before_condition', (True, False)) +def test_lift_loop_llvm_canonical(increment_before_condition): sdfg = dace.SDFG('llvm_canonical') N = dace.symbol('N') sdfg.add_symbol('i', dace.int32) @@ -72,8 +74,12 @@ def test_lift_loop_llvm_canonical(): sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0')) sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0})) - sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) - sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + if increment_before_condition: + sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + else: + sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2'})) sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) sdfg.add_edge(loopexit, exitstate, InterstateEdge()) @@ -160,5 +166,6 @@ def test_lift_loop_llvm_canonical_while(): if __name__ == '__main__': test_lift_regular_for_loop() - test_lift_loop_llvm_canonical() + test_lift_loop_llvm_canonical(True) + test_lift_loop_llvm_canonical(False) test_lift_loop_llvm_canonical_while() From 2bd5d007b82cdaddc818a9cacada5820b36d1150 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 2 Oct 2024 17:38:09 +0200 Subject: [PATCH 014/108] Work on conditional lifting --- dace/codegen/control_flow.py | 7 +- dace/codegen/targets/framecode.py | 6 +- .../simplification/control_flow_raising.py | 77 ++++++++++++++- .../control_flow_raising_test.py | 98 +++++++++++++++++++ 4 files changed, 177 insertions(+), 11 deletions(-) create mode 100644 tests/passes/simplification/control_flow_raising_test.py diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index d0cd3da8b4..f5559984e7 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -274,14 +274,11 @@ def as_cpp(self, codegen, symbols) -> str: for i, elem in enumerate(self.elements): expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. - if (isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region) or - (isinstance(elem, GeneralLoopScope) and elem.loop)): + if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region): if isinstance(elem, BasicCFBlock): g_elem = elem.state - elif isinstance(elem, GeneralBlock): - g_elem = elem.region else: - g_elem = elem.loop + g_elem = elem.region cfg = g_elem.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg out_edges = cfg.out_edges(g_elem) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 2d3c524771..f7b8338269 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,5 +1,4 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import ast import collections import copy import re @@ -17,6 +16,7 @@ from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator from dace.codegen.tools.type_inference import infer_expr_type +from dace.frontend.python import astutils from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils @@ -920,10 +920,10 @@ def generate_code(self, if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: init_assignment = cfr.init_statement.code[0] - if isinstance(init_assignment, ast.Assign): + if isinstance(init_assignment, astutils.ast.Assign): init_assignment = init_assignment.value if not cfr.loop_variable in interstate_symbols: - interstate_symbols[cfr.loop_variable] = infer_expr_type(ast.unparse(init_assignment)) + interstate_symbols[cfr.loop_variable] = infer_expr_type(astutils.unparse(init_assignment)) if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 2f92ab4e86..5cfd6ffba6 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -1,13 +1,19 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Optional, Tuple +import networkx as nx from dace import properties +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.utils import dfs_conditional from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.interstate.loop_lifting import LoopLifting @properties.make_properties @transformation.experimental_cfg_block_compatible -class ControlFlowLifting(ppl.Pass): +class ControlFlowRaising(ppl.Pass): CATEGORY: str = 'Simplification' @@ -17,6 +23,71 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG - def apply_pass(self, top_sdfg: ppl.SDFG, _) -> ppl.Any | None: + def _lift_conditionals(self, sdfg: SDFG) -> int: + cfgs = list(sdfg.all_control_flow_regions()) + n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + + for region in cfgs: + dummy_exit = region.add_state('__DACE_DUMMY') + for s in region.sink_nodes(): + if s is not dummy_exit: + region.add_edge(s, dummy_exit, InterstateEdge()) + idom = nx.immediate_dominators(region.nx, region.start_block) + alldoms = cfg_analysis.all_dominators(region, idom) + branch_merges = cfg_analysis.branch_merges(region, idom, alldoms) + + for block in region.nodes(): + graph = block.parent_graph + oedges = graph.out_edges(block) + if len(oedges) > 1 and block in branch_merges: + merge_block = branch_merges[block] + + # Construct the branching block. + conditional = ConditionalBlock('conditional_' + block.label, sdfg, graph) + graph.add_node(conditional) + # Connect it. + graph.add_edge(block, conditional, InterstateEdge()) + + # Populate branches. + for i, oe in enumerate(oedges): + branch_name = 'branch_' + str(i) + '_' + block.label + branch = ControlFlowRegion(branch_name, sdfg) + conditional.branches.append([oe.data.condition, branch]) + if oe.dst is merge_block: + # Empty branch. + continue + + branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) + branch_start = branch.add_state(branch_name + '_start', is_start_block=True) + branch.add_nodes_from(branch_nodes) + branch_end = branch.add_state(branch_name + '_end') + branch.add_edge(branch_start, oe.dst, InterstateEdge(assignments=oe.data.assignments)) + added = set() + for e in graph.all_edges(*branch_nodes): + if not (e in added): + added.add(e) + if e is oe: + continue + elif e.dst is merge_block: + branch.add_edge(e.src, branch_end, e.data) + else: + branch.add_edge(e.src, e.dst, e.data) + graph.remove_nodes_from(branch_nodes) + + # Connect to the end of the branch / what happens after. + if merge_block is not dummy_exit: + graph.add_edge(conditional, merge_block, InterstateEdge()) + region.remove_node(dummy_exit) + + n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + return n_cond_regions_post - n_cond_regions_pre + + def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: + lifted_loops = 0 + lifted_branches = 0 for sdfg in top_sdfg.all_sdfgs_recursive(): - sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) + lifted_loops += sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) + lifted_branches += self._lift_conditionals(sdfg) + if lifted_branches == 0 and lifted_loops == 0: + return None + return lifted_loops, lifted_branches diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py new file mode 100644 index 0000000000..53e01df12f --- /dev/null +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -0,0 +1,98 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +import numpy as np +from dace.sdfg.state import ConditionalBlock +from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising + + +def test_dataflow_if_check(): + + @dace.program + def dataflow_if_check(A: dace.int32[10], i: dace.int64): + if A[i] < 10: + return 0 + elif A[i] == 10: + return 10 + return 100 + + sdfg = dataflow_if_check.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.__experimental_cfg_block_compatible__ = True + ppl.apply_pass(sdfg, {}) + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + A = np.zeros((10,), np.int32) + A[4] = 10 + A[5] = 100 + assert sdfg(A, 0)[0] == 0 + assert sdfg(A, 4)[0] == 10 + assert sdfg(A, 5)[0] == 100 + assert sdfg(A, 6)[0] == 0 + + +def test_nested_if_chain(): + + @dace.program + def nested_if_chain(i: dace.int64): + if i < 2: + return 0 + else: + if i < 4: + return 1 + else: + if i < 6: + return 2 + else: + if i < 8: + return 3 + else: + return 4 + + sdfg = nested_if_chain.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert nested_if_chain(0)[0] == 0 + assert nested_if_chain(2)[0] == 1 + assert nested_if_chain(4)[0] == 2 + assert nested_if_chain(7)[0] == 3 + assert nested_if_chain(15)[0] == 4 + + +def test_elif_chain(): + + @dace.program + def elif_chain(i: dace.int64): + if i < 2: + return 0 + elif i < 4: + return 1 + elif i < 6: + return 2 + elif i < 8: + return 3 + else: + return 4 + + elif_chain.use_experimental_cfg_blocks = True + sdfg = elif_chain.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert elif_chain(0)[0] == 0 + assert elif_chain(2)[0] == 1 + assert elif_chain(4)[0] == 2 + assert elif_chain(7)[0] == 3 + assert elif_chain(15)[0] == 4 + + +if __name__ == '__main__': + test_dataflow_if_check() + test_nested_if_chain() + test_elif_chain() From dc640a95b7bc1afe65d8880db3ee46f000b51642 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 7 Oct 2024 09:40:08 +0200 Subject: [PATCH 015/108] Improve conditional block interface --- dace/frontend/python/newast.py | 7 +-- dace/sdfg/state.py | 7 ++- .../simplification/control_flow_raising.py | 2 +- tests/sdfg/conditional_region_test.py | 50 +++++++++---------- 4 files changed, 34 insertions(+), 32 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 0d40e13282..cacf15d785 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2565,8 +2565,7 @@ def visit_If(self, node: ast.If): self._on_block_added(cond_block) if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) - cond_block.branches.append((CodeBlock(cond), if_body)) - if_body.parent_graph = self.cfg_target + cond_block.add_branch(CodeBlock(cond), if_body) # Visit recursively self._recursive_visit(node.body, 'if', node.lineno, if_body, False) @@ -2575,9 +2574,7 @@ def visit_If(self, node: ast.If): if len(node.orelse) > 0: else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) - #cond_block.branches.append((CodeBlock(cond_else), else_body)) - cond_block.branches.append((None, else_body)) - else_body.parent_graph = self.cfg_target + cond_block.add_branch(None, else_body) # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 4dc93a8d9d..4460802a62 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3251,7 +3251,12 @@ def __repr__(self) -> str: @property def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]: return self._branches - + + def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): + self._branches.append([condition, branch]) + branch.parent_graph = self.parent_graph + branch.sdfg = self.sdfg + def nodes(self) -> List['ControlFlowBlock']: return [node for _, node in self._branches if node is not None] diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 5cfd6ffba6..c6aced2a6d 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -52,7 +52,7 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: for i, oe in enumerate(oedges): branch_name = 'branch_' + str(i) + '_' + block.label branch = ControlFlowRegion(branch_name, sdfg) - conditional.branches.append([oe.data.condition, branch]) + conditional.add_branch(oe.data.condition, branch) if oe.dst is merge_block: # Empty branch. continue diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 4e4eda3f44..a6a46bc568 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -10,20 +10,20 @@ def test_cond_region_if(): sdfg = dace.SDFG('regular_if') - sdfg.add_array("A", (1,), dace.float32) - sdfg.add_symbol("i", dace.int32) + sdfg.add_array('A', (1,), dace.float32) + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalBlock("if1") + if1 = ConditionalBlock('if1') sdfg.add_node(if1) sdfg.add_edge(state0, if1, InterstateEdge()) - if_body = ControlFlowRegion("if_body", sdfg=sdfg) - if1.branches.append((CodeBlock("i == 1"), if_body)) + if_body = ControlFlowRegion('if_body', sdfg=sdfg) + if1.add_branch((CodeBlock('i == 1'), if_body)) - state1 = if_body.add_state("state1", is_start_block=True) + state1 = if_body.add_state('state1', is_start_block=True) acc_a = state1.add_access('A') - t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = 100') state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) assert sdfg.is_valid() @@ -36,14 +36,14 @@ def test_cond_region_if(): assert A[0] == 1 def test_serialization(): - sdfg = SDFG("test_serialization") - cond_region = ConditionalBlock("cond_region") + sdfg = SDFG('test_serialization') + cond_region = ConditionalBlock('cond_region') sdfg.add_node(cond_region, is_start_block=True) - sdfg.add_symbol("i", dace.int32) + sdfg.add_symbol('i', dace.int32) for j in range(10): - cfg = ControlFlowRegion(f"cfg_{j}", sdfg) - cond_region.branches.append((CodeBlock(f"i == {j}"), cfg)) + cfg = ControlFlowRegion(f'cfg_{j}', sdfg) + cond_region.add_branch(CodeBlock(f'i == {j}'), cfg) assert sdfg.is_valid() @@ -52,32 +52,32 @@ def test_serialization(): new_cond_region: ConditionalBlock = new_sdfg.nodes()[0] for j in range(10): condition, cfg = new_cond_region.branches[j] - assert condition == CodeBlock(f"i == {j}") - assert cfg.label == f"cfg_{j}" + assert condition == CodeBlock(f'i == {j}') + assert cfg.label == f'cfg_{j}' def test_if_else(): sdfg = dace.SDFG('regular_if_else') - sdfg.add_array("A", (1,), dace.float32) - sdfg.add_symbol("i", dace.int32) + sdfg.add_array('A', (1,), dace.float32) + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalBlock("if1") + if1 = ConditionalBlock('if1') sdfg.add_node(if1) sdfg.add_edge(state0, if1, InterstateEdge()) - if_body = ControlFlowRegion("if_body", sdfg=sdfg) - state1 = if_body.add_state("state1", is_start_block=True) + if_body = ControlFlowRegion('if_body', sdfg=sdfg) + state1 = if_body.add_state('state1', is_start_block=True) acc_a = state1.add_access('A') - t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = 100') state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) - if1.branches.append((CodeBlock("i == 1"), if_body)) + if1.add_branch(CodeBlock('i == 1'), if_body) - else_body = ControlFlowRegion("else_body", sdfg=sdfg) - state2 = else_body.add_state("state1", is_start_block=True) + else_body = ControlFlowRegion('else_body', sdfg=sdfg) + state2 = else_body.add_state('state1', is_start_block=True) acc_a2 = state2.add_access('A') - t2 = state2.add_tasklet("t2", None, {"a"}, "a = 200") + t2 = state2.add_tasklet('t2', None, {'a'}, 'a = 200') state2.add_edge(t2, 'a', acc_a2, None, dace.Memlet('A[0]')) - if1.branches.append((CodeBlock("i == 0"), else_body)) + if1.add_branch(CodeBlock('i == 0'), else_body) assert sdfg.is_valid() A = np.ones((1,), dtype=np.float32) From 02d53fc3362384f47dce08c5442908bdd6af95ac Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 7 Oct 2024 09:51:19 +0200 Subject: [PATCH 016/108] Remove files from other PR --- .../analysis/control_flow_region_analysis.py | 227 ------------------ .../passes/analysis/loop_analysis.py | 213 ---------------- .../control_flow_region_analysis_test.py | 117 --------- 3 files changed, 557 deletions(-) delete mode 100644 dace/transformation/passes/analysis/control_flow_region_analysis.py delete mode 100644 dace/transformation/passes/analysis/loop_analysis.py delete mode 100644 tests/passes/analysis/control_flow_region_analysis_test.py diff --git a/dace/transformation/passes/analysis/control_flow_region_analysis.py b/dace/transformation/passes/analysis/control_flow_region_analysis.py deleted file mode 100644 index 377765c31b..0000000000 --- a/dace/transformation/passes/analysis/control_flow_region_analysis.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. - -from collections import defaultdict -from typing import Any, Dict, List, Set, Tuple - -import networkx as nx - -from dace import SDFG, SDFGState -from dace import data as dt -from dace import properties -from dace.memlet import Memlet -from dace.sdfg import nodes, propagation -from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDictT -from dace.sdfg.graph import MultiConnectorEdge -from dace.sdfg.scope import ScopeTree -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion -from dace.subsets import Range -from dace.transformation import pass_pipeline as ppl, transformation -from dace.transformation.passes.analysis.analysis import AccessRanges, ControlFlowBlockReachability - - -@properties.make_properties -@transformation.experimental_cfg_block_compatible -class StateDataDependence(ppl.Pass): - """ - Analyze the input dependencies and the underapproximated outputs of states. - """ - - CATEGORY: str = 'Analysis' - - def modifies(self) -> ppl.Modifies: - return ppl.Modifies.Nothing - - def should_reapply(self, modified: ppl.Modifies) -> bool: - return modified & (ppl.Modifies.Nodes | ppl.Modifies.Memlets) - - def depends_on(self): - return {UnderapproximateWrites} - - def _gather_reads_scope(self, state: SDFGState, scope: ScopeTree, - writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]], - not_covered_reads: Set[Memlet], scope_ranges: Dict[str, Range]): - scope_nodes = state.scope_children()[scope.entry] - data_nodes_in_scope: Set[nodes.AccessNode] = set([n for n in scope_nodes if isinstance(nodes.AccessNode)]) - if scope.entry is not None: - # propagate - pass - - for anode in data_nodes_in_scope: - for oedge in state.out_edges(anode): - if not oedge.data.is_empty(): - root_edge = state.memlet_tree(oedge).root().edge - read_subset = root_edge.data.src_subset - covered = False - for [write, to] in writes[anode.data]: - if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): - covered = True - break - if not covered: - not_covered_reads.add(root_edge.data) - - def _state_get_deps(self, state: SDFGState, - underapproximated_writes: UnderapproximateWritesDictT) -> Tuple[Set[Memlet], Set[Memlet]]: - # Collect underapproximated write memlets. - writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]] = defaultdict(lambda: []) - for anode in state.data_nodes(): - for iedge in state.in_edges(anode): - if not iedge.data.is_empty(): - root_edge = state.memlet_tree(iedge).root().edge - if root_edge in underapproximated_writes['approximation']: - writes[anode.data].append([underapproximated_writes['approximation'][root_edge], anode]) - else: - writes[anode.data].append([root_edge.data, anode]) - - # Go over (overapproximated) reads and check if they are covered by writes. - not_covered_reads: List[Tuple[MultiConnectorEdge[Memlet], Memlet]] = [] - for anode in state.data_nodes(): - for oedge in state.out_edges(anode): - if not oedge.data.is_empty(): - if oedge.data.data != anode.data: - # Special case for memlets copying data out of the scope, which are by default aligned with the - # outside data container. In this case, the source container must either be a scalar, or the - # read subset is contained in the memlet's `other_subset` property. - # See `dace.sdfg.propagation.align_memlet` for more. - desc = state.sdfg.data(anode.data) - if oedge.data.other_subset is not None: - read_subset = oedge.data.other_subset - elif isinstance(desc, dt.Scalar) or (isinstance(desc, dt.Array) and desc.total_size == 1): - read_subset = Range([(0, 0, 1)] * len(desc.shape)) - else: - raise RuntimeError('Invalid memlet range detected in StateDataDependence analysis') - else: - read_subset = oedge.data.src_subset or oedge.data.subset - covered = False - for [write, to] in writes[anode.data]: - if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): - covered = True - break - if not covered: - not_covered_reads.append([oedge, oedge.data]) - # Make sure all reads are propagated if they happen inside maps. We do not need to do this for writes, because - # it is already taken care of by the write underapproximation analysis pass. - self._recursive_propagate_reads(state, state.scope_tree()[None], not_covered_reads) - - write_set = set() - for data in writes: - for memlet, _ in writes[data]: - write_set.add(memlet) - - read_set = set() - for reads in not_covered_reads: - read_set.add(reads[1]) - - return read_set, write_set - - def _recursive_propagate_reads(self, state: SDFGState, scope: ScopeTree, - read_edges: Set[Tuple[MultiConnectorEdge[Memlet], Memlet]]): - for child in scope.children: - self._recursive_propagate_reads(state, child, read_edges) - - if scope.entry is not None: - if isinstance(scope.entry, nodes.MapEntry): - for read_tuple in read_edges: - read_edge, read_memlet = read_tuple - for param in scope.entry.map.params: - if param in read_memlet.free_symbols: - aligned_memlet = propagation.align_memlet(state, read_edge, True) - propagated_memlet = propagation.propagate_memlet(state, aligned_memlet, scope.entry, True) - read_tuple[1] = propagated_memlet - - def apply_pass(self, top_sdfg: SDFG, - pipeline_results: Dict[str, Any]) -> Dict[int, Dict[SDFGState, Tuple[Set[Memlet], Set[Memlet]]]]: - """ - :return: For each SDFG, a dictionary mapping states to sets of their input and output memlets. - """ - - results = defaultdict(lambda: defaultdict(lambda: [set(), set()])) - - underapprox_writes_dict: Dict[int, Any] = pipeline_results[UnderapproximateWrites.__name__] - for sdfg in top_sdfg.all_sdfgs_recursive(): - uapprox_writes = underapprox_writes_dict[sdfg.cfg_id] - for state in sdfg.states(): - input_dependencies, output_dependencies = self._state_get_deps(state, uapprox_writes) - results[sdfg.cfg_id][state] = [input_dependencies, output_dependencies] - - return results - - -@properties.make_properties -@transformation.experimental_cfg_block_compatible -class CFGDataDependence(ppl.Pass): - """ - Analyze the input dependencies and the underapproximated outputs of control flow graphs / regions. - """ - - CATEGORY: str = 'Analysis' - - def modifies(self) -> ppl.Modifies: - return ppl.Modifies.Nothing - - def should_reapply(self, modified: ppl.Modifies) -> bool: - return modified & ppl.Modifies.CFG - - def depends_on(self): - return {StateDataDependence, ControlFlowBlockReachability} - - def _recursive_get_deps_region(self, cfg: ControlFlowRegion, - results: Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]], - state_deps: Dict[int, Dict[SDFGState, Tuple[Set[Memlet], Set[Memlet]]]], - cfg_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] - ) -> Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]: - # Collect all individual reads and writes happening inside the region. - region_reads: Dict[str, List[Tuple[Memlet, ControlFlowBlock]]] = defaultdict(list) - region_writes: Dict[str, List[Tuple[Memlet, ControlFlowBlock]]] = defaultdict(list) - for node in cfg.nodes(): - if isinstance(node, SDFGState): - for read in state_deps[node.sdfg.cfg_id][node][0]: - region_reads[read.data].append([read, node]) - for write in state_deps[node.sdfg.cfg_id][node][1]: - region_writes[write.data].append([write, node]) - elif isinstance(node, ControlFlowRegion): - sub_reads, sub_writes = self._recursive_get_deps_region(node, results, state_deps, cfg_reach) - for data in sub_reads: - for read in sub_reads[data]: - region_reads[data].append([read, node]) - for data in sub_writes: - for write in sub_writes[data]: - region_writes[data].append([write, node]) - - # Through reachability analysis, check which writes cover which reads. - # TODO: make sure this doesn't cover up reads if we have a cycle in the CFG. - not_covered_reads: Dict[str, Set[Memlet]] = defaultdict(set) - for data in region_reads: - for read, read_block in region_reads[data]: - covered = False - for write, write_block in region_writes[data]: - if (write.subset.covers_precise(read.src_subset or read.subset) and - write_block is not read_block and - nx.has_path(cfg.nx, write_block, read_block)): - covered = True - break - if not covered: - not_covered_reads[data].add(read) - - write_set: Dict[str, Set[Memlet]] = defaultdict(set) - for data in region_writes: - for memlet, _ in region_writes[data]: - write_set[data].add(memlet) - - results[cfg.cfg_id] = [not_covered_reads, write_set] - - return not_covered_reads, write_set - - def apply_pass(self, top_sdfg: SDFG, - pipeline_res: Dict[str, Any]) -> Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]]: - """ - :return: For each CFG, a dictionary mapping control flow regions to sets of their input and output memlets. - """ - - results = defaultdict(lambda: defaultdict(lambda: [defaultdict(set), defaultdict(set)])) - - state_deps_dict = pipeline_res[StateDataDependence.__name__] - cfb_reachability_dict = pipeline_res[ControlFlowBlockReachability.__name__] - for sdfg in top_sdfg.all_sdfgs_recursive(): - self._recursive_get_deps_region(sdfg, results, state_deps_dict, cfb_reachability_dict) - - return results diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py deleted file mode 100644 index dd8c5f7446..0000000000 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. - -import ast -from collections import defaultdict -from typing import Any, Dict, Optional, Set, Tuple - -import sympy - -from dace import SDFG, properties, symbolic, transformation -from dace.memlet import Memlet -from dace.sdfg.state import LoopRegion -from dace.subsets import Range, SubsetUnion -from dace.transformation import pass_pipeline as ppl -from dace.transformation.pass_pipeline import Pass -from dace.transformation.passes.analysis.control_flow_region_analysis import CFGDataDependence - - -@properties.make_properties -@transformation.experimental_cfg_block_compatible -class LoopCarryDependencyAnalysis(ppl.Pass): - """ - Analyze the data dependencies between loop iterations for loop regions. - """ - - CATEGORY: str = 'Analysis' - - _non_analyzable_loops: Set[LoopRegion] - - def __init__(self): - self._non_analyzable_loops = set() - super().__init__() - - def modifies(self) -> ppl.Modifies: - return ppl.Modifies.Nothing - - def should_reapply(self, modified: ppl.Modifies) -> bool: - return modified & ppl.Modifies.CFG - - def depends_on(self): - return {CFGDataDependence} - - def _intersects(self, loop: LoopRegion, write_subset: Range, read_subset: Range, update: sympy.Basic) -> bool: - """ - Check if a write subset intersects a read subset after being offset by the loop stride. The offset is performed - based on the symbolic loop update assignment expression. - """ - offset = update - symbolic.symbol(loop.loop_variable) - offset_list = [] - for i in range(write_subset.dims()): - if loop.loop_variable in write_subset.get_free_symbols_by_indices([i]): - offset_list.append(offset) - else: - offset_list.append(0) - offset_write = write_subset.offset_new(offset_list, True) - return offset_write.intersects(read_subset) - - def apply_pass(self, top_sdfg: SDFG, - pipeline_results: Dict[str, Any]) -> Dict[int, Dict[LoopRegion, Dict[Memlet, Set[Memlet]]]]: - """ - :return: For each SDFG, a dictionary mapping loop regions to a dictionary that resolves reads to writes in the - same loop, from which they may carry a RAW dependency. - """ - results = defaultdict(lambda: defaultdict(dict)) - - cfg_dependency_dict: Dict[int, Tuple[Dict[str, Set[Memlet]], Dict[str, Set[Memlet]]]] = pipeline_results[ - CFGDataDependence.__name__ - ] - for cfg in top_sdfg.all_control_flow_regions(recursive=True): - if isinstance(cfg, LoopRegion): - loop_inputs, loop_outputs = cfg_dependency_dict[cfg.cfg_id] - update_assignment = None - loop_dependencies: Dict[Memlet, Set[Memlet]] = dict() - - for data in loop_inputs: - if not data in loop_outputs: - continue - - for input in loop_inputs[data]: - read_subset = input.src_subset or input.subset - dep_candidates: Set[Memlet] = set() - if cfg.loop_variable and cfg.loop_variable in input.free_symbols: - # If the iteration variable is involved in an access, we need to first offset it by the loop - # stride and then check for an overlap/intersection. If one is found after offsetting, there - # is a RAW loop carry dependency. - for output in loop_outputs[data]: - # Get and cache the update assignment for the loop. - if update_assignment is None and not cfg in self._non_analyzable_loops: - update_assignment = get_update_assignment(cfg) - if update_assignment is None: - self._non_analyzable_loops(cfg) - - if isinstance(output.subset, SubsetUnion): - if any([self._intersects(cfg, s, read_subset, update_assignment) - for s in output.subset.subset_list]): - dep_candidates.add(output) - elif self._intersects(cfg, output.subset, read_subset, update_assignment): - dep_candidates.add(output) - else: - # Check for basic overlaps/intersections in RAW loop carry dependencies, when there is no - # iteration variable involved. - for output in loop_outputs[data]: - if isinstance(output.subset, SubsetUnion): - if any([s.intersects(read_subset) for s in output.subset.subset_list]): - dep_candidates.add(output) - elif output.subset.intersects(read_subset): - dep_candidates.add(output) - loop_dependencies[input] = dep_candidates - results[cfg.sdfg.cfg_id][cfg] = loop_dependencies - - return results - - -class FindAssignment(ast.NodeVisitor): - - assignments: Dict[str, str] - multiple: bool - - def __init__(self): - self.assignments = {} - self.multiple = False - - def visit_Assign(self, node: ast.Assign) -> Any: - for tgt in node.targets: - if isinstance(tgt, ast.Name): - if tgt.id in self.assignments: - self.multiple = True - self.assignments[tgt.id] = ast.unparse(node.value) - return self.generic_visit(node) - - -def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: - """ - Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). - """ - end: Optional[symbolic.SymbolicType] = None - a = sympy.Wild('a') - condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) - itersym = symbolic.pystr_to_symbolic(loop.loop_variable) - match = condition.match(itersym < a) - if match: - end = match[a] - 1 - if end is None: - match = condition.match(itersym <= a) - if match: - end = match[a] - if end is None: - match = condition.match(itersym > a) - if match: - end = match[a] + 1 - if end is None: - match = condition.match(itersym >= a) - if match: - end = match[a] - return end - - -def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: - """ - Parse a loop region's init statement to identify the exact init assignment expression. - """ - init_stmt = loop.init_statement - if init_stmt is None: - return None - - init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] - assignments: Dict[str, str] = {} - for code in init_codes_list: - visitor = FindAssignment() - visitor.visit(code) - if visitor.multiple: - return None - for assign in visitor.assignments: - if assign in assignments: - return None - assignments[assign] = visitor.assignments[assign] - - if loop.loop_variable in assignments: - return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) - - return None - - -def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: - """ - Parse a loop region's update statement to identify the exact update assignment expression. - """ - update_stmt = loop.update_statement - if update_stmt is None: - return None - - update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] - assignments: Dict[str, str] = {} - for code in update_codes_list: - visitor = FindAssignment() - visitor.visit(code) - if visitor.multiple: - return None - for assign in visitor.assignments: - if assign in assignments: - return None - assignments[assign] = visitor.assignments[assign] - - if loop.loop_variable in assignments: - return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) - - return None - - -def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: - update_assignment = get_update_assignment(loop) - if update_assignment: - return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) - return None diff --git a/tests/passes/analysis/control_flow_region_analysis_test.py b/tests/passes/analysis/control_flow_region_analysis_test.py deleted file mode 100644 index bf0742f3f1..0000000000 --- a/tests/passes/analysis/control_flow_region_analysis_test.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -""" Tests analysis passes related to control flow regions (control_flow_region_analysis.py). """ - -import dace -from dace.memlet import Memlet -from dace.sdfg.sdfg import SDFG -from dace.transformation.pass_pipeline import Pipeline -from dace.transformation.passes.analysis.control_flow_region_analysis import StateDataDependence - - -def test_state_data_dependence_with_contained_read(): - sdfg = SDFG('myprog') - N = dace.symbol('N') - sdfg.add_array('A', (N, ), dace.float32) - sdfg.add_array('B', (N, ), dace.float32) - mystate = sdfg.add_state('mystate', is_start_block=True) - b_read = mystate.add_access('B') - b_write_second_half = mystate.add_access('B') - b_write_first_half = mystate.add_access('B') - a_read_write = mystate.add_access('A') - first_entry, first_exit = mystate.add_map('map_one', {'i': '0:0.5*N'}) - second_entry, second_exit = mystate.add_map('map_two', {'i': '0:0.5*N'}) - t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 + 1.0') - t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') - t3 = mystate.add_tasklet('t3', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') - mystate.add_memlet_path(b_read, first_entry, t1, memlet=Memlet('B[i]'), dst_conn='i1') - mystate.add_memlet_path(b_read, first_entry, t2, memlet=Memlet('B[i]'), dst_conn='i1') - mystate.add_memlet_path(t1, first_exit, a_read_write, memlet=Memlet('A[i]'), src_conn='o1') - mystate.add_memlet_path(t2, first_exit, b_write_second_half, memlet=Memlet('B[N - (i + 1)]'), src_conn='o1') - mystate.add_memlet_path(a_read_write, second_entry, t3, memlet=Memlet('A[i]'), dst_conn='i1') - mystate.add_memlet_path(t3, second_exit, b_write_first_half, memlet=Memlet('B[i]'), src_conn='o1') - - res = {} - Pipeline([StateDataDependence()]).apply_pass(sdfg, res) - state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] - - assert len(state_data_deps[0]) == 1 - read_memlet: Memlet = list(state_data_deps[0])[0] - assert read_memlet.data == 'B' - assert read_memlet.subset[0][0] == 0 - assert read_memlet.subset[0][1] == 0.5 * N - 1 - - assert len(state_data_deps[1]) == 3 - - -def test_state_data_dependence_with_contained_read_in_map(): - sdfg = SDFG('myprog') - N = dace.symbol('N') - sdfg.add_array('A', (N, ), dace.float32) - sdfg.add_transient('tmp', (N, ), dace.float32) - sdfg.add_array('B', (N, ), dace.float32) - mystate = sdfg.add_state('mystate', is_start_block=True) - a_read = mystate.add_access('A') - tmp = mystate.add_access('tmp') - b_write = mystate.add_access('B') - m_entry, m_exit = mystate.add_map('my_map', {'i': 'N'}) - t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') - t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') - mystate.add_memlet_path(a_read, m_entry, t1, memlet=Memlet('A[i]'), dst_conn='i1') - mystate.add_memlet_path(t1, tmp, memlet=Memlet('tmp[i]'), src_conn='o1') - mystate.add_memlet_path(tmp, t2, memlet=Memlet('tmp[i]'), dst_conn='i1') - mystate.add_memlet_path(t2, m_exit, b_write, memlet=Memlet('B[i]'), src_conn='o1') - - res = {} - Pipeline([StateDataDependence()]).apply_pass(sdfg, res) - state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] - - assert len(state_data_deps[0]) == 1 - read_memlet: Memlet = list(state_data_deps[0])[0] - assert read_memlet.data == 'A' - - assert len(state_data_deps[1]) == 2 - out_containers = [m.data for m in state_data_deps[1]] - assert 'B' in out_containers - assert 'tmp' in out_containers - assert 'A' not in out_containers - - -def test_state_data_dependence_with_non_contained_read_in_map(): - sdfg = SDFG('myprog') - N = dace.symbol('N') - sdfg.add_array('A', (N, ), dace.float32) - sdfg.add_array('tmp', (N, ), dace.float32) - sdfg.add_array('B', (N, ), dace.float32) - mystate = sdfg.add_state('mystate', is_start_block=True) - a_read = mystate.add_access('A') - tmp = mystate.add_access('tmp') - b_write = mystate.add_access('B') - m_entry, m_exit = mystate.add_map('my_map', {'i': '0:ceil(N/2)'}) - t1 = mystate.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 * 2.0') - t2 = mystate.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 - 1.0') - mystate.add_memlet_path(a_read, m_entry, t1, memlet=Memlet('A[i]'), dst_conn='i1') - mystate.add_memlet_path(t1, tmp, memlet=Memlet('tmp[i]'), src_conn='o1') - mystate.add_memlet_path(tmp, t2, memlet=Memlet('tmp[i+ceil(N/2)]'), dst_conn='i1') - mystate.add_memlet_path(t2, m_exit, b_write, memlet=Memlet('B[i]'), src_conn='o1') - - res = {} - Pipeline([StateDataDependence()]).apply_pass(sdfg, res) - state_data_deps = res[StateDataDependence.__name__][0][sdfg.states()[0]] - - assert len(state_data_deps[0]) == 2 - in_containers = [m.data for m in state_data_deps[0]] - assert 'A' in in_containers - assert 'tmp' in in_containers - assert 'B' not in in_containers - - assert len(state_data_deps[1]) == 2 - out_containers = [m.data for m in state_data_deps[1]] - assert 'B' in out_containers - assert 'tmp' in out_containers - assert 'A' not in out_containers - - -if __name__ == '__main__': - test_state_data_dependence_with_contained_read() - test_state_data_dependence_with_contained_read_in_map() - test_state_data_dependence_with_non_contained_read_in_map() From 074a990e6ff465add21589c7b0937ab8066123b0 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 7 Oct 2024 14:21:59 +0200 Subject: [PATCH 017/108] Add back missing file --- .../passes/analysis/loop_analysis.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 dace/transformation/passes/analysis/loop_analysis.py diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py new file mode 100644 index 0000000000..9dbfb62e9f --- /dev/null +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -0,0 +1,116 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +from collections import defaultdict +from typing import Any, Dict, Optional, Set, Tuple + +import sympy + +from dace import SDFG, properties, symbolic, transformation +from dace.memlet import Memlet +from dace.sdfg.state import LoopRegion +from dace.subsets import Range, SubsetUnion +from dace.transformation import pass_pipeline as ppl + + +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = ast.unparse(node.value) + return self.generic_visit(node) + + +def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). + """ + end: Optional[symbolic.SymbolicType] = None + a = sympy.Wild('a') + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + itersym = symbolic.pystr_to_symbolic(loop.loop_variable) + match = condition.match(itersym < a) + if match: + end = match[a] - 1 + if end is None: + match = condition.match(itersym <= a) + if match: + end = match[a] + if end is None: + match = condition.match(itersym > a) + if match: + end = match[a] + 1 + if end is None: + match = condition.match(itersym >= a) + if match: + end = match[a] + return end + + +def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's init statement to identify the exact init assignment expression. + """ + init_stmt = loop.init_statement + if init_stmt is None: + return None + + init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] + assignments: Dict[str, str] = {} + for code in init_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's update statement to identify the exact update assignment expression. + """ + update_stmt = loop.update_statement + if update_stmt is None: + return None + + update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + assignments: Dict[str, str] = {} + for code in update_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + update_assignment = get_update_assignment(loop) + if update_assignment: + return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) + return None From cd8a258c3eb7ddde6f99c8c5ec8f497918e45895 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 7 Oct 2024 14:29:36 +0200 Subject: [PATCH 018/108] Adapt dead state elimination --- dace/frontend/python/interface.py | 2 +- dace/frontend/python/parser.py | 2 +- dace/sdfg/state.py | 7 + .../passes/analysis/analysis.py | 2 + .../passes/dead_state_elimination.py | 124 ++++++++++++------ dace/transformation/passes/fusion_inline.py | 1 + tests/passes/dead_code_elimination_test.py | 58 ++++++++ 7 files changed, 157 insertions(+), 39 deletions(-) diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 14164054d3..06bef0ba37 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -44,7 +44,7 @@ def program(f: F, recompile: bool = True, distributed_compilation: bool = False, constant_functions=False, - use_experimental_cfg_blocks=False, + use_experimental_cfg_blocks=True, **kwargs) -> Callable[..., parser.DaceProgram]: """ Entry point to a data-centric program. For methods and ``classmethod``s, use diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index d99be1265d..4a650325bd 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -153,7 +153,7 @@ def __init__(self, recompile: bool = True, distributed_compilation: bool = False, method: bool = False, - use_experimental_cfg_blocks: bool = False): + use_experimental_cfg_blocks: bool = True): from dace.codegen import compiled_sdfg # Avoid import loops self.f = f diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 4460802a62..db35e2b92b 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3257,6 +3257,13 @@ def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): branch.parent_graph = self.parent_graph branch.sdfg = self.sdfg + def remove_branch(self, branch: ControlFlowRegion): + filtered_branches = [] + for c, b in self._branches: + if b is not branch: + filtered_branches.append((c, b)) + self._branches = filtered_branches + def nodes(self) -> List['ControlFlowBlock']: return [node for _, node in self._branches if node is not None] diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 095319f807..45153c23b4 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -341,6 +341,7 @@ def apply_pass(self, top_sdfg: SDFG, @properties.make_properties +@transformation.single_level_sdfg_only class SymbolWriteScopes(ppl.Pass): """ For each symbol, create a dictionary mapping each writing interstate edge to that symbol to the set of interstate @@ -445,6 +446,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, @properties.make_properties +@transformation.single_level_sdfg_only class ScalarWriteShadowScopes(ppl.Pass): """ For each scalar or array of size 1, create a dictionary mapping writes to that data container to the set of reads diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 43239fe9af..ad6bc23b85 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -1,18 +1,19 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import collections import sympy as sp -from typing import Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Union from dace import SDFG, InterstateEdge, SDFGState, symbolic, properties from dace.properties import CodeBlock from dace.sdfg.graph import Edge -from dace.sdfg.validation import InvalidSDFGInterstateEdgeError +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.validation import InvalidSDFGInterstateEdgeError, InvalidSDFGNodeError from dace.transformation import pass_pipeline as ppl, transformation @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class DeadStateElimination(ppl.Pass): """ Removes all unreachable states (e.g., due to a branch that will never be taken) from an SDFG. @@ -25,7 +26,7 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If connectivity or any edges were changed, some more states might be dead - return modified & (ppl.Modifies.InterstateEdges | ppl.Modifies.States) + return modified & ppl.Modifies.CFG def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[InterstateEdge]]]]: """ @@ -38,42 +39,63 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters :param initial_symbols: If not None, sets values of initial symbols. :return: A set of the removed states, or None if nothing was changed. """ - # Mark dead states and remove them - dead_states, dead_edges, annotated = self.find_dead_states(sdfg, set_unconditional_edges=True) - - for e in dead_edges: - sdfg.remove_edge(e) - sdfg.remove_nodes_from(dead_states) + result: Set[Union[ControlFlowBlock, InterstateEdge]] = set() + removed_regions: Set[ControlFlowRegion] = set() + for cfg in list(sdfg.all_control_flow_regions()): + if cfg in removed_regions: + continue - result = dead_states | dead_edges + # Mark dead blocks and remove them + dead_blocks, dead_edges, annotated = self.find_dead_control_flow(cfg, set_unconditional_edges=True) + for e in dead_edges: + cfg.remove_edge(e) + for block in dead_blocks: + cfg.remove_node(block) + if isinstance(block, ControlFlowRegion): + removed_regions.add(block) + + region_result = dead_blocks | dead_edges + result |= region_result + + for node in cfg.nodes(): + if isinstance(node, ConditionalBlock): + dead_branches = self._find_dead_branches(node) + if len(dead_branches) < len(node.branches): + for _, b in dead_branches: + result.add(b) + node.remove_branch(b) + else: + result.add(node) + cfg.remove_node(block) if not annotated: return result or None else: return result or set() # Return an empty set if edges were annotated - def find_dead_states( + def find_dead_control_flow( self, - sdfg: SDFG, - set_unconditional_edges: bool = True) -> Tuple[Set[SDFGState], Set[Edge[InterstateEdge]], bool]: + cfg: ControlFlowRegion, + set_unconditional_edges: bool = True) -> Tuple[Set[ControlFlowBlock], Set[Edge[InterstateEdge]], bool]: """ - Finds "dead" (unreachable) states in an SDFG. A state is deemed unreachable if it is: + Finds "dead" (unreachable) control flow in a CFG. A block is deemed unreachable if it is: - * Unreachable from the starting state + * Unreachable from the starting block * Conditions leading to it will always evaluate to False - * There is another unconditional (always True) inter-state edge that leads to another state + * There is another unconditional (always True) inter-state edge that leads to another block - :param sdfg: The SDFG to traverse. + :param cfg: The CFG to traverse. :param set_unconditional_edges: If True, conditions of edges evaluated as unconditional are removed. - :return: A 3-tuple of (unreachable states, unreachable edges, were edges annotated). + :return: A 3-tuple of (unreachable blocks, unreachable edges, were edges annotated). """ - visited: Set[SDFGState] = set() + sdfg = cfg.sdfg if cfg.sdfg is not None else cfg + visited: Set[ControlFlowBlock] = set() dead_edges: Set[Edge[InterstateEdge]] = set() edges_annotated = False # Run a modified BFS where definitely False edges are not traversed, or if there is an - # unconditional edge the rest are not. The inverse of the visited states is the dead set. - queue = collections.deque([sdfg.start_state]) + # unconditional edge the rest are not. The inverse of the visited blocks is the dead set. + queue = collections.deque([cfg.start_block]) while len(queue) > 0: node = queue.popleft() if node in visited: @@ -82,13 +104,13 @@ def find_dead_states( # First, check for unconditional edges unconditional = None - for e in sdfg.out_edges(node): + for e in cfg.out_edges(node): # If an unconditional edge is found, ignore all other outgoing edges if self.is_definitely_taken(e.data, sdfg): # If more than one unconditional outgoing edge exist, fail with Invalid SDFG if unconditional is not None: - raise InvalidSDFGInterstateEdgeError('Multiple unconditional edges leave the same state', sdfg, - sdfg.edge_id(e)) + raise InvalidSDFGInterstateEdgeError('Multiple unconditional edges leave the same block', cfg, + cfg.edge_id(e)) unconditional = e if set_unconditional_edges and not e.data.is_unconditional(): # Annotate edge as unconditional @@ -101,7 +123,7 @@ def find_dead_states( continue if unconditional is not None: # Unconditional edge exists, skip traversal # Remove other (now never taken) edges from graph - for e in sdfg.out_edges(node): + for e in cfg.out_edges(node): if e is not unconditional: dead_edges.add(e) @@ -109,7 +131,7 @@ def find_dead_states( # End of unconditional check # Check outgoing edges normally - for e in sdfg.out_edges(node): + for e in cfg.out_edges(node): next_node = e.dst # Test for edges that definitely evaluate to False @@ -122,7 +144,31 @@ def find_dead_states( queue.append(next_node) # Dead states are states that are not live (i.e., visited) - return set(sdfg.nodes()) - visited, dead_edges, edges_annotated + return set(cfg.nodes()) - visited, dead_edges, edges_annotated + + def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, ControlFlowRegion]]: + dead_branches = [] + unconditional = None + for cond, branch in block.branches: + # If an unconditional branch is found, ignore all other branches + if cond.as_string.strip() == '1' or self._is_truthy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + # If more than one unconditional outgoing edge exist, fail with Invalid SDFG + if unconditional is not None: + raise InvalidSDFGNodeError('Multiple branches that are unconditionally true in conditional block', + block.parent_graph, block.block_id) + unconditional = branch + if unconditional is not None: + # Remove other (now never taken) branches + for cond, branch in block.branches: + if branch is not unconditional: + dead_branches.append([cond, branch]) + else: + # Check if any branches are certainly never taken. + for cond, branch in block.branches: + if self._is_falsy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + dead_branches.append([cond, branch]) + + return dead_branches def report(self, pass_retval: Set[Union[SDFGState, Edge[InterstateEdge]]]) -> str: if pass_retval is not None and not pass_retval: @@ -137,13 +183,15 @@ def is_definitely_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return True # Evaluate condition - scond = edge.condition_sympy() - if scond == True or scond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False): + return self._is_truthy(edge.condition_sympy(), sdfg) + + def _is_truthy(self, cond: sp.Basic, sdfg: SDFG) -> bool: + if cond == True or cond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False): return True # Evaluate non-optional arrays - scond = symbolic.evaluate_optional_arrays(scond, sdfg) - if scond == True: + cond = symbolic.evaluate_optional_arrays(cond, sdfg) + if cond == True: return True # Indeterminate or False condition @@ -155,13 +203,15 @@ def is_definitely_not_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return False # Evaluate condition - scond = edge.condition_sympy() - if scond == False or scond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False): + return self._is_falsy(edge.condition_sympy(), sdfg) + + def _is_falsy(self, cond: sp.Basic, sdfg: SDFG) -> bool: + if cond == False or cond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False): return True # Evaluate non-optional arrays - scond = symbolic.evaluate_optional_arrays(scond, sdfg) - if scond == False: + cond = symbolic.evaluate_optional_arrays(cond, sdfg) + if cond == False: return True # Indeterminate or True condition diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 9a97afb569..15d1eeca5f 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -52,6 +52,7 @@ def report(self, pass_retval: int) -> str: @dataclass(unsafe_hash=True) @properties.make_properties +@experimental_cfg_block_compatible class InlineSDFGs(ppl.Pass): """ Inlines all possible nested SDFGs (and sub-SDFGs). diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index f8920b0538..ba7bd91f30 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -4,6 +4,8 @@ import numpy as np import pytest import dace +from dace.properties import CodeBlock +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.dead_state_elimination import DeadStateElimination from dace.transformation.passes.dead_dataflow_elimination import DeadDataflowElimination @@ -45,6 +47,60 @@ def test_dse_unconditional(): assert set(sdfg.states()) == {s, s2, e} +def test_dse_inside_loop(): + sdfg = dace.SDFG('dse_inside_loop') + sdfg.add_symbol('a', dace.int32) + loop = LoopRegion('loop', 'i < 10', 'i', 'i = 0', 'i = i + 1') + start = sdfg.add_state(is_start_block=True) + sdfg.add_node(loop) + end = sdfg.add_state() + sdfg.add_edge(start, loop, dace.InterstateEdge()) + sdfg.add_edge(loop, end, dace.InterstateEdge()) + s = loop.add_state(is_start_block=True) + s1 = loop.add_state() + s2 = loop.add_state() + s3 = loop.add_state() + e = loop.add_state() + loop.add_edge(s, s1, dace.InterstateEdge('a > 0')) + loop.add_edge(s, s2, dace.InterstateEdge('a >= a')) # Always True + loop.add_edge(s, s3, dace.InterstateEdge('a < 0')) + loop.add_edge(s1, e, dace.InterstateEdge()) + loop.add_edge(s2, e, dace.InterstateEdge()) + loop.add_edge(s3, e, dace.InterstateEdge()) + + DeadStateElimination().apply_pass(sdfg, {}) + assert set(sdfg.states()) == {start, s, s2, e, end} + + +def test_dse_inside_loop_conditional(): + sdfg = dace.SDFG('dse_inside_loop') + sdfg.add_symbol('a', dace.int32) + loop = LoopRegion('loop', 'i < 10', 'i', 'i = 0', 'i = i + 1') + start = sdfg.add_state(is_start_block=True) + sdfg.add_node(loop) + end = sdfg.add_state() + sdfg.add_edge(start, loop, dace.InterstateEdge()) + sdfg.add_edge(loop, end, dace.InterstateEdge()) + s = loop.add_state(is_start_block=True) + cond_block = ConditionalBlock('cond', sdfg, loop) + loop.add_node(cond_block) + b1 = ControlFlowRegion('b1', sdfg) + b1.add_state() + cond_block.add_branch(CodeBlock('a > 0'), b1) + b2 = ControlFlowRegion('b2', sdfg) + s2 = b2.add_state() + cond_block.add_branch(CodeBlock('a >= a'), b2) + b3 = ControlFlowRegion('b3', sdfg) + b3.add_state() + cond_block.add_branch(CodeBlock('a < 0'), b3) + e = loop.add_state() + loop.add_edge(s, cond_block, dace.InterstateEdge()) + loop.add_edge(cond_block, e, dace.InterstateEdge()) + + DeadStateElimination().apply_pass(sdfg, {}) + assert set(sdfg.states()) == {start, s, s2, e, end} + + def test_dde_simple(): @dace.program @@ -294,6 +350,8 @@ def test_dce_add_type_hint_of_variable(): if __name__ == '__main__': test_dse_simple() test_dse_unconditional() + test_dse_inside_loop() + test_dse_inside_loop_conditional() test_dde_simple() test_dde_libnode() test_dde_access_node_in_scope(False) From f77655fab3b9e574ff351b7257b8f437af09774f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 7 Oct 2024 14:38:21 +0200 Subject: [PATCH 019/108] Adapt DeadDataflowElimination --- dace/transformation/pass_pipeline.py | 4 ++-- dace/transformation/passes/dead_dataflow_elimination.py | 9 +++++---- tests/passes/dead_code_elimination_test.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 0da8a96165..8c748ef8d5 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -511,8 +511,8 @@ def apply_subpass(self, sdfg: SDFG, p: Pass, state: Dict[str, Any]) -> Optional[ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Any]]: if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(self, '__experimental_cfg_block_compatible__') or - self.__experimental_cfg_block_compatible__ == False): + if (type(self) != Pipeline and (not hasattr(self, '__experimental_cfg_block_compatible__') or + self.__experimental_cfg_block_compatible__ == False)): warnings.warn('Pipeline ' + self.__class__.__name__ + ' is being skipped due to incompatibility with ' + 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 856924abd2..c95aee665b 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -19,7 +19,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class DeadDataflowElimination(ppl.Pass): """ Removes unused computations from SDFG states. @@ -59,13 +59,14 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D # Depends on the following analysis passes: # * State reachability # * Read/write access sets per state - reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results['StateReachability'][sdfg.cfg_id] - access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results['AccessSets'][sdfg.cfg_id] + reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results[ap.StateReachability.__name__][sdfg.cfg_id] + access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results[ap.AccessSets.__name__][sdfg.cfg_id] result: Dict[SDFGState, Set[str]] = defaultdict(set) # Traverse SDFG backwards try: - state_order = list(cfg.blockorder_topological_sort(sdfg)) + state_order: List[SDFGState] = list(cfg.blockorder_topological_sort(sdfg, recursive=True, + ignore_nonstate_blocks=True)) except KeyError: return None for state in reversed(state_order): diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index ba7bd91f30..14d380c463 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -265,7 +265,7 @@ def dce_tester(a: dace.float64[20], b: dace.float64[20]): sdfg = dce_tester.to_sdfg(simplify=False) result = Pipeline([DeadDataflowElimination(), DeadStateElimination()]).apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.number_of_nodes() <= 5 + assert sdfg.number_of_nodes() <= 6 # Check that arrays were removed assert all('c' not in [n.data for n in state.data_nodes()] for state in sdfg.nodes()) From 49475ec67b71a1ab3e6a347b5874d8049326f238 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 8 Oct 2024 10:50:04 +0200 Subject: [PATCH 020/108] Fixes --- dace/transformation/interstate/loop_lifting.py | 12 +++++------- tests/sdfg/conditional_region_test.py | 2 +- .../transformations/interstate/loop_lifting_test.py | 12 +++++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 604aa74d16..0c64c1548a 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -75,9 +75,6 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): graph.add_edge(loop, after, InterstateEdge(assignments=exit_edge.data.assignments)) loop.add_node(first_state, is_start_block=True) - for n in full_body: - if n is not first_state: - loop.add_node(n) added = set() for e in graph.all_edges(*full_body): if e.src in full_body and e.dst in full_body: @@ -85,11 +82,12 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): added.add(e) if e is incr_edge: if left_over_incr_assignments != {}: - loop.add_edge(e.src, loop.add_state(label + '_tail'), - InterstateEdge(assignments=left_over_incr_assignments)) + dst = loop.add_state(label + '_tail') if not inverted else e.dst + loop.add_edge(e.src, dst, InterstateEdge(assignments=left_over_incr_assignments)) elif e is cond_edge: - e.data.condition = properties.CodeBlock('1') - loop.add_edge(e.src, e.dst, e.data) + if not inverted: + e.data.condition = properties.CodeBlock('1') + loop.add_edge(e.src, e.dst, e.data) else: loop.add_edge(e.src, e.dst, e.data) diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index a6a46bc568..0be40f43d3 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -19,7 +19,7 @@ def test_cond_region_if(): sdfg.add_edge(state0, if1, InterstateEdge()) if_body = ControlFlowRegion('if_body', sdfg=sdfg) - if1.add_branch((CodeBlock('i == 1'), if_body)) + if1.add_branch(CodeBlock('i == 1'), if_body) state1 = if_body.add_state('state1', is_start_block=True) acc_a = state1.add_access('A') diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index 843209794f..21de6bc884 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -55,7 +55,8 @@ def test_lift_regular_for_loop(): @pytest.mark.parametrize('increment_before_condition', (True, False)) def test_lift_loop_llvm_canonical(increment_before_condition): - sdfg = dace.SDFG('llvm_canonical') + addendum = '_incr_before_cond' if increment_before_condition else '' + sdfg = dace.SDFG('llvm_canonical' + addendum) N = dace.symbol('N') sdfg.add_symbol('i', dace.int32) sdfg.add_symbol('j', dace.int32) @@ -77,10 +78,11 @@ def test_lift_loop_llvm_canonical(increment_before_condition): if increment_before_condition: sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'})) sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) else: sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) - sdfg.add_edge(latch, body, InterstateEdge(condition='i < N', assignments={'i': 'i + 2'})) - sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N - 2', assignments={'i': 'i + 2'})) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N - 2', assignments={'k': 2})) sdfg.add_edge(loopexit, exitstate, InterstateEdge()) a_access = body.add_access('A') @@ -128,8 +130,8 @@ def test_lift_loop_llvm_canonical_while(): sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0')) sdfg.add_edge(preheader, body, InterstateEdge(assignments={'k': 0})) sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'})) - sdfg.add_edge(latch, body, InterstateEdge(condition='i < N')) - sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2})) + sdfg.add_edge(latch, body, InterstateEdge(condition='i < N - 2')) + sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N - 2', assignments={'k': 2})) sdfg.add_edge(loopexit, exitstate, InterstateEdge()) i_init_write = entry.add_access('i') From 2a74901599908ceab8fd80392562d2719833a727 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 8 Oct 2024 12:39:38 +0200 Subject: [PATCH 021/108] Bugfix --- dace/codegen/targets/framecode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index f7b8338269..e1e5b5a1da 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -923,7 +923,8 @@ def generate_code(self, if isinstance(init_assignment, astutils.ast.Assign): init_assignment = init_assignment.value if not cfr.loop_variable in interstate_symbols: - interstate_symbols[cfr.loop_variable] = infer_expr_type(astutils.unparse(init_assignment)) + interstate_symbols[cfr.loop_variable] = infer_expr_type(astutils.unparse(init_assignment), + global_symbols) if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] From 8127900322bb89957be42c62c75b034e53dd9e0f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 8 Oct 2024 14:05:55 +0200 Subject: [PATCH 022/108] Adapt trivial loop elimination and state elimination --- .../interstate/state_elimination.py | 23 +++--- .../interstate/trivial_loop_elimination.py | 72 ++++++++++--------- 2 files changed, 51 insertions(+), 44 deletions(-) diff --git a/dace/transformation/interstate/state_elimination.py b/dace/transformation/interstate/state_elimination.py index 2640e30ccc..6ffe9fa468 100644 --- a/dace/transformation/interstate/state_elimination.py +++ b/dace/transformation/interstate/state_elimination.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ State elimination transformations """ import networkx as nx @@ -8,6 +8,7 @@ from dace.properties import CodeBlock from dace.sdfg import nodes, SDFG, SDFGState from dace.sdfg import utils as sdutil +from dace.sdfg.sdfg import InterstateEdge from dace.sdfg.state import ControlFlowRegion from dace.transformation import transformation @@ -222,7 +223,7 @@ def _str_repl(s, d): symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m)) -def _alias_assignments(sdfg, edge): +def _alias_assignments(sdfg: SDFG, edge: InterstateEdge): assignments_to_consider = {} for var, assign in edge.assignments.items(): if assign in sdfg.symbols or (assign in sdfg.arrays and isinstance(sdfg.arrays[assign], dt.Scalar)): @@ -230,7 +231,7 @@ def _alias_assignments(sdfg, edge): return assignments_to_consider -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class SymbolAliasPromotion(transformation.MultiStateTransformation): """ SymbolAliasPromotion moves inter-state assignments that create symbolic @@ -298,12 +299,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): fstate = self.first_state sstate = self.second_state - edge = sdfg.edges_between(fstate, sstate)[0].data - in_edge = sdfg.in_edges(fstate)[0].data + edge = graph.edges_between(fstate, sstate)[0].data + in_edge = graph.in_edges(fstate)[0].data to_consider = _alias_assignments(sdfg, edge) @@ -335,7 +336,7 @@ def apply(self, _, sdfg): in_edge.assignments[k] = v -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class HoistState(transformation.SingleStateTransformation): """ Move a state out of a nested SDFG """ nsdfg = transformation.PatternNode(nodes.NestedSDFG) @@ -359,6 +360,8 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): return False if nsdfg.sdfg.start_state.number_of_nodes() != 0: return False + if any([not isinstance(x, SDFGState) for x in nsdfg.sdfg.nodes()]): + return False # Must have at least two states with a hoistable source state if nsdfg.sdfg.number_of_nodes() < 2: @@ -428,8 +431,8 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): def apply(self, state: SDFGState, sdfg: SDFG): nsdfg: nodes.NestedSDFG = self.nsdfg - new_state = sdfg.add_state_before(state) - isedge = sdfg.edges_between(new_state, state)[0] + new_state = state.parent_graph.add_state_before(state) + isedge = state.parent_graph.edges_between(new_state, state)[0] # Find relevant symbol and data descriptor mapping mapping: Dict[str, str] = {} @@ -438,7 +441,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): mapping.update({k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data for k in nsdfg.out_connectors}) # Get internal state and interstate edge - source_state = nsdfg.sdfg.start_state + source_state: SDFGState = nsdfg.sdfg.start_state nisedge = nsdfg.sdfg.out_edges(source_state)[0] # Add state contents (nodes) diff --git a/dace/transformation/interstate/trivial_loop_elimination.py b/dace/transformation/interstate/trivial_loop_elimination.py index 411d9ff07d..981d3833b6 100644 --- a/dace/transformation/interstate/trivial_loop_elimination.py +++ b/dace/transformation/interstate/trivial_loop_elimination.py @@ -1,33 +1,39 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Eliminates trivial loop """ from dace import sdfg as sd -from dace.properties import CodeBlock +from dace.sdfg import utils as sdutil +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.transformation import helpers, transformation -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) +from dace.transformation.passes.analysis import loop_analysis -@transformation.single_level_sdfg_only -class TrivialLoopElimination(DetectLoop, transformation.MultiStateTransformation): +@transformation.experimental_cfg_block_compatible +class TrivialLoopElimination(transformation.MultiStateTransformation): """ Eliminates loops with a single loop iteration. """ - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + loop = transformation.PatternNode(LoopRegion) + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # Obtain iteration variable, range, and stride - loop_info = self.loop_information() - if not loop_info: + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # Check if this is a for-loop with known range. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + stride = loop_analysis.get_loop_stride(self.loop) + if start is None or end is None or stride is None: return False - _, (start, end, step), _ = loop_info + # Check if this is a trivial loop. try: - if step > 0 and start + step < end + 1: + if stride > 0 and start + stride < end + 1: return False - if step < 0 and start + step > end - 1: + if stride < 0 and start + stride > end - 1: return False except: # if the relation can't be determined it's not a trivial loop @@ -35,28 +41,26 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg: sd.SDFG): - # Obtain loop information + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Obtain iteration variable, range and stride - itervar, (start, end, step), (_, body_end) = self.loop_information() - states = self.loop_body() - - for state in states: - state.replace(itervar, start) - - # Remove loop - sdfg.remove_edge(self.loop_increment_edge()) + itervar = self.loop.loop_variable + start = loop_analysis.get_init_assignment(self.loop) - init_edge = self.loop_init_edge() - init_edge.data.assignments = {} - sdfg.add_edge(init_edge.src, self.loop_begin, init_edge.data) - sdfg.remove_edge(init_edge) + self.loop.replace(itervar, start) - exit_edge = self.loop_exit_edge() - exit_edge.data.condition = CodeBlock("1") - sdfg.add_edge(body_end, exit_edge.dst, exit_edge.data) - sdfg.remove_edge(exit_edge) + # Add the loop contents to the parent graph. + graph.add_node(self.loop.start_block) + for e in graph.in_edges(self.loop): + graph.add_edge(e.src, self.loop.start_block, e.data) + sink = graph.add_state(self.loop.label + '_sink') + for n in self.loop.sink_nodes(): + graph.add_edge(n, sink, InterstateEdge()) + for e in graph.out_edges(self.loop): + graph.add_edge(sink, e.dst, e.data) + for e in self.loop.edges(): + graph.add_edge(e.src, e.dst, e.data) - sdfg.remove_nodes_from(self.loop_meta_states()) + # Remove loop and if necessary also the loop variable. + graph.remove_node(self.loop) if itervar in sdfg.symbols and helpers.is_symbol_unused(sdfg, itervar): sdfg.remove_symbol(itervar) From 04f41c2edcea55467cdcd1b79c59afd014582beb Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 8 Oct 2024 14:11:42 +0200 Subject: [PATCH 023/108] gdapt passes: - StateFusionExtended - ArrayElimination - OptionalArrayInference --- dace/sdfg/state.py | 2 - dace/sdfg/utils.py | 30 +++++++------ .../state_fusion_with_happens_before.py | 41 +++++++++--------- .../passes/array_elimination.py | 16 +++---- .../passes/dead_state_elimination.py | 15 +++---- dace/transformation/passes/optional_arrays.py | 42 ++++++++++++------- 6 files changed, 81 insertions(+), 65 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index db35e2b92b..d058577a3e 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -11,8 +11,6 @@ from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload) -import sympy - import dace from dace.frontend.python import astutils import dace.serialize diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 5b9ce1a431..ed179df0cf 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,7 +13,8 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import ConditionalBlock, SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion +from dace.sdfg.state import (ConditionalBlock, ControlFlowBlock, SDFGState, StateSubgraphView, LoopRegion, + ControlFlowRegion) from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs @@ -1585,41 +1586,44 @@ def is_fpga_kernel(sdfg, state): return at_least_one_fpga_array +CFBlockDictT = Dict[ControlFlowBlock, ControlFlowBlock] + + def postdominators( - sdfg: SDFG, + cfg: ControlFlowRegion, return_alldoms: bool = False -) -> Optional[Union[Dict[SDFGState, SDFGState], Tuple[Dict[SDFGState, SDFGState], Dict[SDFGState, Set[SDFGState]]]]]: +) -> Optional[Union[CFBlockDictT, Tuple[CFBlockDictT, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]]]: """ - Return the immediate postdominators of an SDFG. This may require creating new nodes and removing them, which - happens in-place on the SDFG. + Return the immediate postdominators of a CFG. This may require creating new nodes and removing them, which + happens in-place on the CFG. - :param sdfg: The SDFG to generate the postdominators from. + :param cfg: The CFG to generate the postdominators from. :param return_alldoms: If True, returns the "all postdominators" dictionary as well. :return: Immediate postdominators, or a 2-tuple of (ipostdom, allpostdoms) if ``return_alldoms`` is True. """ - from dace.sdfg.analysis import cfg + from dace.sdfg.analysis import cfg as cfg_analysis # Get immediate post-dominators - sink_nodes = sdfg.sink_nodes() + sink_nodes = cfg.sink_nodes() if len(sink_nodes) > 1: - sink = sdfg.add_state() + sink = cfg.add_state() for snode in sink_nodes: - sdfg.add_edge(snode, sink, dace.InterstateEdge()) + cfg.add_edge(snode, sink, dace.InterstateEdge()) elif len(sink_nodes) == 0: return None else: sink = sink_nodes[0] - ipostdom: Dict[SDFGState, SDFGState] = nx.immediate_dominators(sdfg._nx.reverse(), sink) + ipostdom: CFBlockDictT = nx.immediate_dominators(cfg._nx.reverse(), sink) if return_alldoms: - allpostdoms = cfg.all_dominators(sdfg, ipostdom) + allpostdoms = cfg_analysis.all_dominators(cfg, ipostdom) retval = (ipostdom, allpostdoms) else: retval = ipostdom # If a new sink was added for post-dominator computation, remove it if len(sink_nodes) > 1: - sdfg.remove_node(sink) + cfg.remove_node(sink) return retval diff --git a/dace/transformation/interstate/state_fusion_with_happens_before.py b/dace/transformation/interstate/state_fusion_with_happens_before.py index 408f5a76f2..ae2007e59f 100644 --- a/dace/transformation/interstate/state_fusion_with_happens_before.py +++ b/dace/transformation/interstate/state_fusion_with_happens_before.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ State fusion transformation """ from typing import Dict, List, Set @@ -9,7 +9,7 @@ from dace.config import Config from dace.sdfg import nodes from dace.sdfg import utils as sdutil -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.transformation import transformation @@ -31,7 +31,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class StateFusionExtended(transformation.MultiStateTransformation): """ Implements the state-fusion transformation extended to fuse states with RAW and WAW dependencies. An empty memlet is used to represent a dependency between two subgraphs with RAW and WAW dependencies. @@ -461,33 +461,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg): first_state: SDFGState = self.first_state second_state: SDFGState = self.second_state # Remove interstate edge(s) - edges = sdfg.edges_between(first_state, second_state) + edges = graph.edges_between(first_state, second_state) for edge in edges: if edge.data.assignments: - for src, dst, other_data in sdfg.in_edges(first_state): + for src, dst, other_data in graph.in_edges(first_state): other_data.assignments.update(edge.data.assignments) - sdfg.remove_edge(edge) + graph.remove_edge(edge) # Special case 1: first state is empty if first_state.is_empty(): - sdutil.change_edge_dest(sdfg, first_state, second_state) - sdfg.remove_node(first_state) - if sdfg.start_state == first_state: - sdfg.start_state = sdfg.node_id(second_state) + sdutil.change_edge_dest(graph, first_state, second_state) + graph.remove_node(first_state) + if graph.start_block == first_state: + graph.start_block = graph.node_id(second_state) return # Special case 2: second state is empty if second_state.is_empty(): - sdutil.change_edge_src(sdfg, second_state, first_state) - sdutil.change_edge_dest(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + sdutil.change_edge_dest(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) return # Normal case: both states are not empty @@ -495,7 +495,6 @@ def apply(self, _, sdfg): # Find source/sink (data) nodes first_input = [node for node in first_state.source_nodes() if isinstance(node, nodes.AccessNode)] first_output = [node for node in first_state.sink_nodes() if isinstance(node, nodes.AccessNode)] - second_input = [node for node in second_state.source_nodes() if isinstance(node, nodes.AccessNode)] top2 = top_level_nodes(second_state) @@ -585,7 +584,7 @@ def apply(self, _, sdfg): merged_nodes.add(n) # Redirect edges and remove second state - sdutil.change_edge_src(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index 46411478d5..6681ed6da0 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -13,7 +13,7 @@ @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class ArrayElimination(ppl.Pass): """ Merges and removes arrays and their corresponding accesses. This includes redundant array copies, unnecessary views, @@ -48,7 +48,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S # Traverse SDFG backwards try: - state_order = list(cfg.blockorder_topological_sort(sdfg)) + state_order = list(cfg.blockorder_topological_sort(sdfg, recursive=True, ignore_nonstate_blocks=True)) except KeyError: return None for state in reversed(state_order): @@ -132,14 +132,14 @@ def remove_redundant_views(self, sdfg: SDFG, state: SDFGState, access_nodes: Dic """ removed_nodes: Set[nodes.AccessNode] = set() xforms = [RemoveSliceView()] - state_id = sdfg.node_id(state) + state_id = state.block_id for nodeset in access_nodes.values(): for anode in list(nodeset): for xform in xforms: # Quick path to setup match candidate = {type(xform).view: anode} - xform.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + xform.setup_match(sdfg, state.parent_graph.cfg_id, state_id, candidate, 0, override=True) # Try to apply if xform.can_be_applied(state, 0, sdfg): @@ -154,7 +154,7 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, removable_data: Removes access nodes that represent redundant copies and/or views. """ removed_nodes: Set[nodes.AccessNode] = set() - state_id = sdfg.node_id(state) + state_id = state.block_id # Transformations that remove the first access node xforms_first: List[SingleStateTransformation] = [RedundantWriteSlice(), UnsqueezeViewRemove(), RedundantArray()] @@ -184,7 +184,8 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, removable_data: for xform in xforms_first: # Quick path to setup match candidate = {type(xform).in_array: anode, type(xform).out_array: succ} - xform.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + xform.setup_match(sdfg, state.parent_graph.cfg_id, state_id, candidate, 0, + override=True) # Try to apply if xform.can_be_applied(state, 0, sdfg): @@ -204,7 +205,8 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, removable_data: for xform in xforms_second: # Quick path to setup match candidate = {type(xform).in_array: pred, type(xform).out_array: anode} - xform.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + xform.setup_match(sdfg, state.parent_graph.cfg_id, state_id, candidate, 0, + override=True) # Try to apply if xform.can_be_applied(state, 0, sdfg): diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index ad6bc23b85..53e9b4f466 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -149,14 +149,15 @@ def find_dead_control_flow( def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, ControlFlowRegion]]: dead_branches = [] unconditional = None - for cond, branch in block.branches: - # If an unconditional branch is found, ignore all other branches + for i, (cond, branch) in enumerate(block.branches): + if cond is None: + if not i == len(block.branches) - 1: + raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch') + break + # If an unconditional branch is found, ignore all other branches that follow this one. if cond.as_string.strip() == '1' or self._is_truthy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): - # If more than one unconditional outgoing edge exist, fail with Invalid SDFG - if unconditional is not None: - raise InvalidSDFGNodeError('Multiple branches that are unconditionally true in conditional block', - block.parent_graph, block.block_id) unconditional = branch + break if unconditional is not None: # Remove other (now never taken) branches for cond, branch in block.branches: @@ -165,7 +166,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, else: # Check if any branches are certainly never taken. for cond, branch in block.branches: - if self._is_falsy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + if cond is not None and self._is_falsy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): dead_branches.append([cond, branch]) return dead_branches diff --git a/dace/transformation/passes/optional_arrays.py b/dace/transformation/passes/optional_arrays.py index f52ee5af43..366231d1f1 100644 --- a/dace/transformation/passes/optional_arrays.py +++ b/dace/transformation/passes/optional_arrays.py @@ -2,14 +2,15 @@ from typing import Dict, Iterator, Optional, Set, Tuple -from dace import SDFG, SDFGState, data, properties +from dace import SDFG, data, properties from dace.sdfg import nodes from dace.sdfg import utils as sdutil +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion, SDFGState from dace.transformation import pass_pipeline as ppl, transformation @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class OptionalArrayInference(ppl.Pass): """ Infers the ``optional`` property of arrays, i.e., if they can be given None, throughout the SDFG and all nested @@ -63,7 +64,7 @@ def apply_pass(self, arr.optional = parent_arrays[aname] # Change unconditionally-accessed arrays to non-optional - for state in self.traverse_unconditional_states(sdfg): + for state in self.traverse_unconditional_blocks(sdfg, recursive=True): for anode in state.data_nodes(): desc = anode.desc(sdfg) if isinstance(desc, data.Array) and desc.optional is None: @@ -71,7 +72,7 @@ def apply_pass(self, result.add((cfg_id, anode.data)) # Propagate information to nested SDFGs - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): # Create information about parent arrays @@ -96,27 +97,38 @@ def apply_pass(self, return result or None - def traverse_unconditional_states(self, sdfg: SDFG) -> Iterator[SDFGState]: + def traverse_unconditional_blocks(self, cfg: ControlFlowRegion, + recursive: bool = True, + produce_nonstate: bool = False) -> Iterator[ControlFlowBlock]: """ - Traverse SDFG and keep track of whether the state is executed unconditionally. + Traverse CFG and keep track of whether the block is executed unconditionally. """ - ipostdom = sdutil.postdominators(sdfg) - curstate = sdfg.start_state - out_degree = sdfg.out_degree(curstate) + ipostdom = sdutil.postdominators(cfg) + curblock = cfg.start_block + out_degree = cfg.out_degree(curblock) while out_degree > 0: - yield curstate + if produce_nonstate: + yield curblock + elif isinstance(curblock, SDFGState): + yield curblock + + if recursive and isinstance(curblock, ControlFlowRegion) and not isinstance(curblock, LoopRegion): + yield from self.traverse_unconditional_blocks(curblock, recursive, produce_nonstate) if out_degree == 1: # Unconditional, continue to next state - curstate = sdfg.successors(curstate)[0] + curblock = cfg.successors(curblock)[0] elif out_degree > 1: # Conditional branch # Conditional code follows, use immediate post-dominator for next unconditional state - curstate = ipostdom[curstate] + curblock = ipostdom[curblock] # Compute new out degree - if curstate in sdfg.nodes(): - out_degree = sdfg.out_degree(curstate) + if curblock in cfg.nodes(): + out_degree = cfg.out_degree(curblock) else: out_degree = 0 # Yield final state - yield curstate + if produce_nonstate: + yield curblock + elif isinstance(curblock, SDFGState): + yield curblock def report(self, pass_retval: Set[Tuple[int, str]]) -> str: return f'Inferred {len(pass_retval)} optional arrays.' From 852719dec8e1009d361106bdb38e39aef2c15bb8 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 9 Oct 2024 09:40:09 +0200 Subject: [PATCH 024/108] Adapt loop to map --- dace/codegen/control_flow.py | 8 +- dace/sdfg/state.py | 8 +- dace/sdfg/validation.py | 25 +- dace/transformation/interstate/loop_to_map.py | 470 ++++++------------ tests/transformations/loop_to_map_test.py | 15 +- 5 files changed, 203 insertions(+), 323 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index f5559984e7..d7272a214f 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -536,10 +536,14 @@ def as_cpp(self, codegen, symbols) -> str: expr = '' if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable: - init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + lsyms = {} + lsyms.update(symbols) + if codegen.dispatcher.defined_vars.has(self.loop.loop_variable) and not self.loop.loop_variable in lsyms: + lsyms[self.loop.loop_variable] = codegen.dispatcher.defined_vars.get(self.loop.loop_variable)[1] + init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=lsyms) init = init.strip(';') - update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=lsyms) update = update.strip(';') if self.loop.inverted: diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index d058577a3e..6d2a219aab 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -20,11 +20,11 @@ from dace import serialize from dace import subsets as sbs from dace import symbolic -from dace.properties import (CodeBlock, DebugInfoProperty, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, - CodeProperty, make_properties) +from dace.properties import (CodeBlock, DebugInfoProperty, DictProperty, EnumProperty, Property, SubsetProperty, + SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import (MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, - OrderedDiGraph, Edge, generate_element_id) +from dace.sdfg.graph import (MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge, + generate_element_id) from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index e75099276f..af0e18dbda 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -34,7 +34,7 @@ def validate_control_flow_region(sdfg: 'SDFG', symbols: dict, references: Set[int] = None, **context: bool): - from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock + from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock, LoopRegion from dace.sdfg.scope import is_in_scope if len(region.source_nodes()) > 1 and region.start_block is None: @@ -70,8 +70,15 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(edge.src, SDFGState): validate_state(edge.src, region.node_id(edge.src), sdfg, symbols, initialized_transients, references, **context) + elif isinstance(edge.src, ConditionalBlock): + for _, r in edge.src.branches: + if r is not None: + validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.src, ControlFlowRegion): - validate_control_flow_region(sdfg, edge.src, initialized_transients, symbols, references, **context) + lsyms = copy.deepcopy(symbols) + if isinstance(edge.src, LoopRegion) and not edge.src.loop_variable in lsyms: + lsyms[edge.src.loop_variable] = None + validate_control_flow_region(sdfg, edge.src, initialized_transients, lsyms, references, **context) ########################################## # Edge @@ -133,7 +140,10 @@ def validate_control_flow_region(sdfg: 'SDFG', if r is not None: validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.dst, ControlFlowRegion): - validate_control_flow_region(sdfg, edge.dst, initialized_transients, symbols, references, **context) + lsyms = copy.deepcopy(symbols) + if isinstance(edge.dst, LoopRegion) and not edge.dst.loop_variable in lsyms: + lsyms[edge.dst.loop_variable] = None + validate_control_flow_region(sdfg, edge.dst, initialized_transients, lsyms, references, **context) # End of block DFS # If there is only one block, the DFS will miss it @@ -141,8 +151,15 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(start_block, SDFGState): validate_state(start_block, region.node_id(start_block), sdfg, symbols, initialized_transients, references, **context) + elif isinstance(start_block, ConditionalBlock): + for _, r in start_block.branches: + if r is not None: + validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(start_block, ControlFlowRegion): - validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context) + lsyms = copy.deepcopy(symbols) + if isinstance(start_block, LoopRegion) and not start_block.loop_variable in lsyms: + lsyms[start_block.loop_variable] = None + validate_control_flow_region(sdfg, start_block, initialized_transients, lsyms, references, **context) # Validate all inter-state edges (including self-loops not found by DFS) for eid, edge in enumerate(region.edges()): diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 39410f2547..245fff69ab 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -1,23 +1,20 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop to map transformation """ from collections import defaultdict import copy -import itertools import sympy as sp -import networkx as nx -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Set -from dace import data as dt, dtypes, memlet, nodes, registry, sdfg as sd, symbolic, subsets -from dace.properties import Property, make_properties, CodeBlock +from dace import data as dt, memlet, nodes, sdfg as sd, symbolic, subsets, properties from dace.sdfg import graph as gr, nodes -from dace.sdfg import SDFG, SDFGState, InterstateEdge +from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil -from dace.sdfg.analysis import cfg -from dace.frontend.python.astutils import ASTFindReplace -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.state import ControlFlowRegion, LoopRegion import dace.transformation.helpers as helpers from dace.transformation import transformation as xf +from dace.transformation.passes.analysis import loop_analysis def _check_range(subset, a, itersym, b, step): @@ -74,97 +71,71 @@ def _sanitize_by_index(indices: Set[int], subset: subsets.Subset) -> subsets.Ran return type(subset)([t for i, t in enumerate(subset) if i in indices]) -@make_properties -@xf.single_level_sdfg_only -class LoopToMap(DetectLoop, xf.MultiStateTransformation): - """Convert a control flow loop into a dataflow map. Currently only supports - the simple case where there is no overlap between inputs and outputs in - the body of the loop, and where the loop body only consists of a single - state. +@properties.make_properties +@xf.experimental_cfg_block_compatible +class LoopToMap(xf.MultiStateTransformation): + """ + Convert a control flow loop into a dataflow map. Currently only supports the simple case where there is no overlap + between inputs and outputs in the body of the loop, and where the loop body only consists of a single state. """ - itervar = Property( - dtype=str, - allow_none=True, - default=None, - desc='The name of the iteration variable (optional).', - ) - - def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False - - begin = self.loop_begin + loop = xf.PatternNode(LoopRegion) - # Guard state should not contain any dataflow - if expr_index <= 1: - guard = self.loop_guard - if len(guard.nodes()) != 0: - return False + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # If loop cannot be detected, fail - found = self.loop_information(itervar=self.itervar) - if not found: + def can_be_applied(self, graph, expr_index, sdfg, permissive = False): + # If loop information cannot be determined, fail. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) + itervar = self.loop.loop_variable + if start is None or end is None or step is None or itervar is None: return False - itervar, (start, end, step), (_, body_end) = found - - # We cannot handle symbols read from data containers unless they are - # scalar + # We cannot handle symbols read from data containers unless they are scalar. for expr in (start, end, step): if symbolic.contains_sympy_functions(expr): return False - in_order_states = list(cfg.blockorder_topological_sort(sdfg)) - loop_begin_idx = in_order_states.index(begin) - loop_end_idx = in_order_states.index(body_end) - - if loop_end_idx < loop_begin_idx: # Malformed loop - return False - - # Find all loop-body states - states: List[SDFGState] = self.loop_body() - - assert (body_end in states) - + loop_states = set(self.loop.all_states()) write_set: Set[str] = set() - for state in states: + for state in loop_states: _, wset = state.read_and_write_sets() write_set |= wset # Collect symbol reads and writes from inter-state assignments symbols_that_may_be_used: Set[str] = {itervar} used_before_assignment: Set[str] = set() - for state in states: - for e in sdfg.out_edges(state): - # Collect read-before-assigned symbols (this works because the states are always in order, - # see above call to `blockorder_topological_sort`) - read_symbols = e.data.read_symbols() - read_symbols -= symbols_that_may_be_used - used_before_assignment |= read_symbols - # If symbol was read before it is assigned, the loop cannot be parallel - assigned_symbols = set() - for k, v in e.data.assignments.items(): - try: - fsyms = symbolic.pystr_to_symbolic(v).free_symbols - except AttributeError: - fsyms = set() - if not k in fsyms: - assigned_symbols.add(k) - if assigned_symbols & used_before_assignment: - return False + for e in self.loop.all_interstate_edges(): + # Collect read-before-assigned symbols (this works because the states are always in order, + # see above call to `blockorder_topological_sort`) + read_symbols = e.data.read_symbols() + read_symbols -= symbols_that_may_be_used + used_before_assignment |= read_symbols + # If symbol was read before it is assigned, the loop cannot be parallel + assigned_symbols = set() + for k, v in e.data.assignments.items(): + try: + fsyms = symbolic.pystr_to_symbolic(v).free_symbols + except AttributeError: + fsyms = set() + if not k in fsyms: + assigned_symbols.add(k) + if assigned_symbols & used_before_assignment: + return False - symbols_that_may_be_used |= e.data.assignments.keys() + symbols_that_may_be_used |= e.data.assignments.keys() # Get access nodes from other states to isolate local loop variables other_access_nodes: Set[str] = set() - for state in sdfg.nodes(): - if state in states: + for state in sdfg.states(): + if state in loop_states: continue other_access_nodes |= set(n.data for n in state.data_nodes() if sdfg.arrays[n.data].transient) # Add non-transient nodes from loop state - for state in states: + for state in loop_states: other_access_nodes |= set(n.data for n in state.data_nodes() if not sdfg.arrays[n.data].transient) write_memlets: Dict[str, List[memlet.Memlet]] = defaultdict(list) @@ -173,7 +144,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi a = sp.Wild('a', exclude=[itersym]) b = sp.Wild('b', exclude=[itersym]) - for state in states: + for state in loop_states: for dn in state.data_nodes(): if dn.data not in other_access_nodes: continue @@ -196,7 +167,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi write_memlets[dn.data].append(e.data) # After looping over relevant writes, consider reads that may overlap - for state in states: + for state in loop_states: for dn in state.data_nodes(): if dn.data not in other_access_nodes: continue @@ -212,8 +183,8 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Consider reads in inter-state edges (could be in assignments or in condition) isread_set: Set[memlet.Memlet] = set() - for s in states: - for e in sdfg.all_edges(s): + for s in loop_states: + for e in s.parent_graph.all_edges(s): isread_set |= set(e.data.get_read_memlets(sdfg.arrays)) for mmlt in isread_set: if mmlt.data in write_memlets: @@ -223,8 +194,13 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Check that the iteration variable and other symbols are not used on other edges or states # before they are reassigned - for state in in_order_states[loop_begin_idx + 1:]: - if state in states: + in_order_states = list(cfg_analysis.blockorder_topological_sort(sdfg, recursive=True, + ignore_nonstate_blocks=False)) + loop_idx = in_order_states.index(self.loop) + for state in in_order_states[loop_idx + 1:]: + if not isinstance(state, SDFGState): + continue + if state in loop_states: continue # Don't continue in this direction, as all loop symbols have been reassigned if not symbols_that_may_be_used: @@ -346,217 +322,116 @@ def _is_array_thread_local(self, name: str, itervar: str, sdfg: SDFG, states: Li return False return True - def apply(self, _, sdfg: sd.SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): from dace.sdfg.propagation import align_memlet # Obtain loop information - itervar, (start, end, step), (_, body_end) = self.loop_information(itervar=self.itervar) - states = self.loop_body() - body: sd.SDFGState = self.loop_begin - exit_state = self.exit_state - entry_edge = self.loop_condition_edge() - init_edge = self.loop_init_edge() - after_edge = self.loop_exit_edge() - condition_edge = self.loop_condition_edge() - increment_edge = self.loop_increment_edge() + itervar = self.loop.loop_variable + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) nsdfg = None # Nest loop-body states - if len(states) > 1: - - # Find read/write sets - read_set, write_set = set(), set() - for state in states: - rset, wset = state.read_and_write_sets() - read_set |= rset - write_set |= wset - # Add to write set also scalars between tasklets - for src_node in state.nodes(): - if not isinstance(src_node, nodes.Tasklet): - continue - for dst_node in state.nodes(): - if src_node is dst_node: - continue - if not isinstance(dst_node, nodes.Tasklet): - continue - for e in state.edges_between(src_node, dst_node): - if e.data.data and e.data.data in sdfg.arrays: - write_set.add(e.data.data) - # Add data from edges - for src in states: - for dst in states: - for edge in sdfg.edges_between(src, dst): - for s in edge.data.free_symbols: - if s in sdfg.arrays: - read_set.add(s) - - # Find NestedSDFG's unique data - rw_set = read_set | write_set - unique_set = set() - for name in rw_set: - if not sdfg.arrays[name].transient: + states = set(self.loop.all_states()) + # Find read/write sets + read_set, write_set = set(), set() + for state in self.loop.all_states(): + rset, wset = state.read_and_write_sets() + read_set |= rset + write_set |= wset + # Add to write set also scalars between tasklets + for src_node in state.nodes(): + if not isinstance(src_node, nodes.Tasklet): continue - found = False - for state in sdfg.states(): - if state in states: + for dst_node in state.nodes(): + if src_node is dst_node: continue - for node in state.nodes(): - if (isinstance(node, nodes.AccessNode) and node.data == name): - found = True - break - if not found and self._is_array_thread_local(name, itervar, sdfg, states): - unique_set.add(name) - - # Find NestedSDFG's connectors - read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient} - write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient} - - # Create NestedSDFG and add all loop-body states and edges - # Also, find defined symbols in NestedSDFG - fsymbols = set(sdfg.free_symbols) - new_body = sdfg.add_state('single_state_body') - nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body) - nsdfg.add_node(body, is_start_state=True) - body.parent = nsdfg - nexit_state = nsdfg.add_state('exit') - nsymbols = dict() - for state in states: - if state is body: - continue - nsdfg.add_node(state) - state.parent = nsdfg - for state in states: - if state is body: + if not isinstance(dst_node, nodes.Tasklet): + continue + for e in state.edges_between(src_node, dst_node): + if e.data.data and e.data.data in sdfg.arrays: + write_set.add(e.data.data) + # Add data from edges + for edge in self.loop.all_interstate_edges(): + for s in edge.data.free_symbols: + if s in sdfg.arrays: + read_set.add(s) + + # Find NestedSDFG's / Loop's unique data + rw_set = read_set | write_set + unique_set = set() + for name in rw_set: + if not sdfg.arrays[name].transient: + continue + found = False + for state in sdfg.states(): + if state in states: continue - for src, dst, data in sdfg.in_edges(state): - nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols}) - nsdfg.add_edge(src, dst, data) - nsdfg.add_edge(body_end, nexit_state, InterstateEdge()) - - increment_edge = None - - # Specific instructions for loop type - if self.expr_index <= 1: # Natural loop with guard - guard = self.loop_guard - - # Move guard -> body edge to guard -> new_body - for e in sdfg.edges_between(guard, body): - sdfg.remove_edge(e) - condition_edge = sdfg.add_edge(e.src, new_body, e.data) - # Move body_end -> guard edge to new_body -> guard - for e in sdfg.edges_between(body_end, guard): - sdfg.remove_edge(e) - increment_edge = sdfg.add_edge(new_body, e.dst, e.data) - - - elif 1 < self.expr_index <= 3: # Rotated loop - entrystate = self.entry_state - latch = self.loop_latch - - # Move entry edge to entry -> new_body - for src, dst, data, in sdfg.edges_between(entrystate, body): - init_edge = sdfg.add_edge(src, new_body, data) - - # Move body_end -> latch to new_body -> latch - for src, dst, data in sdfg.edges_between(latch, exit_state): - after_edge = sdfg.add_edge(new_body, dst, data) - - elif self.expr_index == 4: # Self-loop - entrystate = self.entry_state - - # Move entry edge to entry -> new_body - for src, dst, data in sdfg.edges_between(entrystate, body): - init_edge = sdfg.add_edge(src, new_body, data) - for src, dst, data in sdfg.edges_between(body, exit_state): - after_edge = sdfg.add_edge(new_body, dst, data) - - - # Delete loop-body states and edges from parent SDFG - sdfg.remove_nodes_from(states) - - # Add NestedSDFG arrays - for name in read_set | write_set: - nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name]) - nsdfg.arrays[name].transient = False - for name in unique_set: - nsdfg.arrays[name] = sdfg.arrays[name] - del sdfg.arrays[name] - - # Add NestedSDFG node - cnode = new_body.add_nested_sdfg(nsdfg, None, read_set, write_set) - if sdfg.parent: - for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items(): - if s not in cnode.symbol_mapping: - cnode.symbol_mapping[s] = m - nsdfg.add_symbol(s, sdfg.symbols[s]) - for name in read_set: - r = new_body.add_read(name) - new_body.add_edge(r, None, cnode, name, memlet.Memlet.from_array(name, sdfg.arrays[name])) - for name in write_set: - w = new_body.add_write(name) - new_body.add_edge(cnode, name, w, None, memlet.Memlet.from_array(name, sdfg.arrays[name])) - - # Fix SDFG symbols - for sym in sdfg.free_symbols - fsymbols: - if sym in sdfg.symbols: - del sdfg.symbols[sym] - for sym, dtype in nsymbols.items(): - nsdfg.symbols[sym] = dtype - - # Change body state reference - body = new_body + for node in state.nodes(): + if (isinstance(node, nodes.AccessNode) and node.data == name): + found = True + break + if not found and self._is_array_thread_local(name, itervar, sdfg, states): + unique_set.add(name) + + # Find NestedSDFG's connectors + read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient} + write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient} + + # Create NestedSDFG and add the loop contents to it. Gaher symbols defined in the NestedSDFG. + fsymbols = set(sdfg.free_symbols) + body = graph.add_state('single_state_body') + nsdfg = SDFG('loop_body', constants=sdfg.constants_prop, parent=body) + nsdfg.add_node(self.loop.start_block, is_start_block=True) + nsymbols = dict() + for block in self.loop.nodes(): + if block is self.loop.start_block: + continue + nsdfg.add_node(state) + for e in self.loop.edges(): + nsymbols.update({s: sdfg.symbols[s] for s in e.data.assignments.keys() if s in sdfg.symbols}) + nsdfg.add_edge(e.src, e.dst, e.data) + + # Add NestedSDFG arrays + for name in read_set | write_set: + nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name]) + nsdfg.arrays[name].transient = False + for name in unique_set: + nsdfg.arrays[name] = sdfg.arrays[name] + del sdfg.arrays[name] + + # Add NestedSDFG node + cnode = body.add_nested_sdfg(nsdfg, None, read_set, write_set) + if sdfg.parent: + for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items(): + if s not in cnode.symbol_mapping: + cnode.symbol_mapping[s] = m + nsdfg.add_symbol(s, sdfg.symbols[s]) + for name in read_set: + r = body.add_read(name) + body.add_edge(r, None, cnode, name, memlet.Memlet.from_array(name, sdfg.arrays[name])) + for name in write_set: + w = body.add_write(name) + body.add_edge(cnode, name, w, None, memlet.Memlet.from_array(name, sdfg.arrays[name])) + + # Fix SDFG symbols + for sym in sdfg.free_symbols - fsymbols: + if sym in sdfg.symbols: + del sdfg.symbols[sym] + for sym, dtype in nsymbols.items(): + nsdfg.symbols[sym] = dtype if (step < 0) == True: - # If step is negative, we have to flip start and end to produce a - # correct map with a positive increment + # If step is negative, we have to flip start and end to produce a correct map with a positive increment. start, end, step = end, start, -step - reentry_assignments = {k: v for k, v in condition_edge.data.assignments.items() if k != itervar} - - # If necessary, make a nested SDFG with assignments - symbols_to_remove = set() - if len(reentry_assignments) > 0: - nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes())) - for sym in entry_edge.data.free_symbols: - if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors: - continue - if sym in sdfg.symbols: - nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym) - nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym]) - elif sym in sdfg.arrays: - if sym in nsdfg.sdfg.arrays: - raise NotImplementedError - rnode = body.add_read(sym) - nsdfg.add_in_connector(sym) - desc = copy.deepcopy(sdfg.arrays[sym]) - desc.transient = False - nsdfg.sdfg.add_datadesc(sym, desc) - body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym)) - for name, desc in nsdfg.sdfg.arrays.items(): - if desc.transient and not self._is_array_thread_local(name, itervar, nsdfg.sdfg, nsdfg.sdfg.states()): - odesc = copy.deepcopy(desc) - sdfg.arrays[name] = odesc - desc.transient = False - wnode = body.add_access(name) - nsdfg.add_out_connector(name) - body.add_edge(nsdfg, name, wnode, None, memlet.Memlet.from_array(name, odesc)) - - nstate = nsdfg.sdfg.node(0) - init_state = nsdfg.sdfg.add_state_before(nstate) - nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0] - nisedge.data.assignments = reentry_assignments - symbols_to_remove = set(nisedge.data.assignments.keys()) - for k in nisedge.data.assignments.keys(): - if k in nsdfg.symbol_mapping: - del nsdfg.symbol_mapping[k] - condition_edge.data.assignments = {} - source_nodes = body.source_nodes() sink_nodes = body.sink_nodes() # Check intermediate notes - intermediate_nodes = [] + intermediate_nodes: List[nodes.AccessNode] = [] for node in body.nodes(): if isinstance(node, nodes.AccessNode) and body.in_degree(node) > 0 and node not in sink_nodes: # Scalars written without WCR must be thread-local @@ -590,7 +465,7 @@ def apply(self, _, sdfg: sd.SDFG): # Direct edges among source and sink access nodes must pass through a tasklet. # We first gather them and handle them later. - direct_edges = set() + direct_edges: Set[gr.Edge[memlet.Memlet]] = set() for n1 in source_nodes: if not isinstance(n1, nodes.AccessNode): continue @@ -623,7 +498,7 @@ def apply(self, _, sdfg: sd.SDFG): body.add_edge_pair(exit, e.src, n, new_memlet, internal_connector=e.src_conn) else: body.add_nedge(n, exit, memlet.Memlet()) - intermediate_sinks = {} + intermediate_sinks: Dict[str, nodes.AccessNode] = {} for n in intermediate_nodes: if isinstance(sdfg.arrays[n.data], dt.View): continue @@ -636,8 +511,8 @@ def apply(self, _, sdfg: sd.SDFG): # Here we handle the direct edges among source and sink access nodes. for e in direct_edges: - src = e.src.data - dst = e.dst.data + src: str = e.src.data + dst: str = e.dst.data if e.data.subset.num_elements() == 1: t = body.add_tasklet(f"{n1}_{n2}", {'__inp'}, {'__out'}, "__out = __inp") src_conn, dst_conn = '__out', '__inp' @@ -667,34 +542,17 @@ def apply(self, _, sdfg: sd.SDFG): if not source_nodes and not sink_nodes: body.add_nedge(entry, exit, memlet.Memlet()) - # Get rid of the loop exit condition edge (it will be readded below) - sdfg.remove_edge(after_edge) - - # Remove the assignment on the edge to the guard - for e in [init_edge, increment_edge]: - if e is None: - continue - if itervar in e.data.assignments: - del e.data.assignments[itervar] + # Redirect edges connected to the loop to connect to the body state instead. + for e in graph.out_edges(self.loop): + graph.add_edge(body, e.dst, e.data) + for e in graph.in_edges(self.loop): + graph.add_edge(e.src, body, e.data) + # Delete the loop and connected edges. + graph.remove_node(self.loop) - # Remove the condition on the entry edge - condition_edge.data.condition = CodeBlock("1") - - # Get rid of backedge to guard - if increment_edge is not None: - sdfg.remove_edge(increment_edge) - - # Route body directly to after state, maintaining any other assignments - # it might have had - sdfg.add_edge(body, exit_state, sd.InterstateEdge(assignments=after_edge.data.assignments)) - - # If this had made the iteration variable a free symbol, we can remove - # it from the SDFG symbols + # If this had made the iteration variable a free symbol, we can remove it from the SDFG symbols if itervar in sdfg.free_symbols: sdfg.remove_symbol(itervar) - for sym in symbols_to_remove: - if sym in sdfg.symbols and helpers.is_symbol_unused(sdfg, sym): - sdfg.remove_symbol(sym) # Reset all nested SDFG parent pointers if nsdfg is not None: @@ -707,3 +565,5 @@ def apply(self, _, sdfg: sd.SDFG): nnode.sdfg.parent_nsdfg_node = nnode nnode.sdfg.parent = nstate nnode.sdfg.parent_sdfg = nsdfg + + sdfg.reset_cfg_list() diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 12d4898858..af945a34fc 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -12,6 +12,7 @@ from dace.sdfg import nodes, propagation from dace.transformation.interstate import LoopToMap, StateFusion from dace.transformation.interstate.loop_detection import DetectLoop +from dace.transformation.interstate.loop_lifting import LoopLifting def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after, log_path): @@ -87,6 +88,8 @@ def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after, post_tasklet = post.add_tasklet("post", {}, {"e"}, "e = i" if use_variable else "e = N") post.add_memlet_path(post_tasklet, e, src_conn="e", memlet=dace.Memlet("E[0]")) + sdfg.apply_transformations_repeated([LoopLifting]) + return sdfg @@ -285,6 +288,7 @@ def test_interstate_dep(): ref = np.random.randint(0, 10, size=(10, ), dtype=np.int32) val = np.copy(ref) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg(A=ref) assert sdfg.apply_transformations(LoopToMap) == 0 @@ -294,7 +298,7 @@ def test_interstate_dep(): def test_need_for_tasklet(): - + # Note: Since the introduction of loop regions, this no longer requires a tasklet. sdfg = dace.SDFG('needs_tasklet') aname, _ = sdfg.add_array('A', (10, ), dace.int32) bname, _ = sdfg.add_array('B', (10, ), dace.int32) @@ -304,14 +308,8 @@ def test_need_for_tasklet(): bnode = body.add_access(bname) body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='i', other_subset='9 - i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) - found = False - for n, s in sdfg.all_nodes_recursive(): - if isinstance(n, nodes.Tasklet): - found = True - break - - assert found A = np.arange(10, dtype=np.int32) B = np.empty((10, ), dtype=np.int32) @@ -331,6 +329,7 @@ def test_need_for_transient(): bnode = body.add_access(bname) body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='0:10, i', other_subset='0:10, 9 - i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) found = False for n, s in sdfg.all_nodes_recursive(): From 352171a4db97d9ba9bbfb704e60f4d76e7b83076 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 9 Oct 2024 12:07:33 +0200 Subject: [PATCH 025/108] Fixes in loop to map --- dace/sdfg/state.py | 12 +++ dace/transformation/interstate/loop_to_map.py | 74 ++++++++++--------- tests/transformations/loop_to_map_test.py | 22 +++--- 3 files changed, 63 insertions(+), 45 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 6d2a219aab..2036ea7940 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -960,6 +960,18 @@ def nodes(self) -> List['ControlFlowBlock']: def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: ... + @overload + def in_edges(self, node: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + @overload + def out_edges(self, node: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + @overload + def all_edges(self, node: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + ################################################################### # Traversal methods diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 245fff69ab..595bf7265e 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -99,34 +99,39 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): if symbolic.contains_sympy_functions(expr): return False + _, write_set = self.loop.read_and_write_sets() loop_states = set(self.loop.all_states()) - write_set: Set[str] = set() - for state in loop_states: - _, wset = state.read_and_write_sets() - write_set |= wset + all_loop_blocks = set(self.loop.all_control_flow_blocks()) + #write_set: Set[str] = set() + #for block in loop_states: + # _, wset = block.read_and_write_sets() + # write_set |= wset # Collect symbol reads and writes from inter-state assignments + in_order_loop_blocks = list(cfg_analysis.blockorder_topological_sort(self.loop, recursive=True, + ignore_nonstate_blocks=False)) symbols_that_may_be_used: Set[str] = {itervar} used_before_assignment: Set[str] = set() - for e in self.loop.all_interstate_edges(): - # Collect read-before-assigned symbols (this works because the states are always in order, - # see above call to `blockorder_topological_sort`) - read_symbols = e.data.read_symbols() - read_symbols -= symbols_that_may_be_used - used_before_assignment |= read_symbols - # If symbol was read before it is assigned, the loop cannot be parallel - assigned_symbols = set() - for k, v in e.data.assignments.items(): - try: - fsyms = symbolic.pystr_to_symbolic(v).free_symbols - except AttributeError: - fsyms = set() - if not k in fsyms: - assigned_symbols.add(k) - if assigned_symbols & used_before_assignment: - return False + for block in in_order_loop_blocks: + for e in block.parent_graph.out_edges(block): + # Collect read-before-assigned symbols (this works because the states are always in order, + # see above call to `blockorder_topological_sort`) + read_symbols = e.data.read_symbols() + read_symbols -= symbols_that_may_be_used + used_before_assignment |= read_symbols + # If symbol was read before it is assigned, the loop cannot be parallel + assigned_symbols = set() + for k, v in e.data.assignments.items(): + try: + fsyms = symbolic.pystr_to_symbolic(v).free_symbols + except AttributeError: + fsyms = set() + if not k in fsyms: + assigned_symbols.add(k) + if assigned_symbols & used_before_assignment: + return False - symbols_that_may_be_used |= e.data.assignments.keys() + symbols_that_may_be_used |= e.data.assignments.keys() # Get access nodes from other states to isolate local loop variables other_access_nodes: Set[str] = set() @@ -183,36 +188,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): # Consider reads in inter-state edges (could be in assignments or in condition) isread_set: Set[memlet.Memlet] = set() - for s in loop_states: - for e in s.parent_graph.all_edges(s): - isread_set |= set(e.data.get_read_memlets(sdfg.arrays)) + for e in self.loop.all_interstate_edges(): + isread_set |= set(e.data.get_read_memlets(sdfg.arrays)) for mmlt in isread_set: if mmlt.data in write_memlets: if not self.test_read_memlet(sdfg, None, None, itersym, itervar, start, end, step, write_memlets, mmlt, mmlt.subset): return False - # Check that the iteration variable and other symbols are not used on other edges or states - # before they are reassigned - in_order_states = list(cfg_analysis.blockorder_topological_sort(sdfg, recursive=True, + # Check that the iteration variable and other symbols are not used on other edges or blocks before they are + # reassigned. + in_order_blocks = list(cfg_analysis.blockorder_topological_sort(sdfg, recursive=True, ignore_nonstate_blocks=False)) - loop_idx = in_order_states.index(self.loop) - for state in in_order_states[loop_idx + 1:]: - if not isinstance(state, SDFGState): - continue - if state in loop_states: + loop_idx = in_order_blocks.index(self.loop) + for block in in_order_blocks[loop_idx + 1:]: + if block in all_loop_blocks: continue # Don't continue in this direction, as all loop symbols have been reassigned if not symbols_that_may_be_used: break # Check state contents - if symbols_that_may_be_used & state.free_symbols: + if symbols_that_may_be_used & block.free_symbols: return False # Check inter-state edges reassigned_symbols: Set[str] = None - for e in sdfg.out_edges(state): + for e in block.parent_graph.out_edges(block): if symbols_that_may_be_used & e.data.read_symbols(): return False @@ -389,7 +391,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): for block in self.loop.nodes(): if block is self.loop.start_block: continue - nsdfg.add_node(state) + nsdfg.add_node(block) for e in self.loop.edges(): nsymbols.update({s: sdfg.symbols[s] for s in e.data.assignments.keys() if s in sdfg.symbols}) nsdfg.add_edge(e.src, e.dst, e.data) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index af945a34fc..c342dedae0 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -298,7 +298,8 @@ def test_interstate_dep(): def test_need_for_tasklet(): - # Note: Since the introduction of loop regions, this no longer requires a tasklet. + # Note: Since the introduction of loop regions this no longer requires a tasklet, as the nested SDFG is directly + # equivalent to the loop region, including all direct access node to access node copy operations. sdfg = dace.SDFG('needs_tasklet') aname, _ = sdfg.add_array('A', (10, ), dace.int32) bname, _ = sdfg.add_array('B', (10, ), dace.int32) @@ -319,7 +320,8 @@ def test_need_for_tasklet(): def test_need_for_transient(): - + # Note: Since the introduction of loop regions this no longer requires a transient, as the nested SDFG is directly + # equivalent to the loop region, including all direct access node to access node copy operations. sdfg = dace.SDFG('needs_transient') aname, _ = sdfg.add_array('A', (10, 10), dace.int32) bname, _ = sdfg.add_array('B', (10, 10), dace.int32) @@ -331,13 +333,6 @@ def test_need_for_transient(): sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) - found = False - for n, s in sdfg.all_nodes_recursive(): - if isinstance(n, nodes.AccessNode) and n.data not in (aname, bname): - found = True - break - - assert found A = np.arange(100, dtype=np.int32).reshape(10, 10).copy() B = np.empty((10, 10), dtype=np.int32) @@ -402,6 +397,7 @@ def test_symbol_write_before_read(): sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(j='0'))) sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(j='j + 1'))) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == 1 @@ -429,6 +425,7 @@ def test_symbol_array_mix(overwrite): sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(sym='sym + tmp'))) sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(sym='sym + 1.0'))) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) @@ -455,6 +452,7 @@ def test_symbol_array_mix_2(parallel): t = body_start.add_tasklet('use', {}, {'o'}, 'o = sym') body_start.add_edge(t, 'o', body_start.add_write('B'), None, dace.Memlet('B[i]')) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == (1 if parallel else 0) @@ -481,6 +479,7 @@ def test_internal_symbol_used_outside(overwrite): else: sdfg.add_edge(after, after_1, dace.InterstateEdge()) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) @@ -510,6 +509,7 @@ def test_shared_local_transient_single_state(): body.add_edge(t1, '__out', anode, None, dace.Memlet(data='A', subset='i')) body.add_edge(anode, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert 'A' in sdfg.arrays @@ -549,6 +549,7 @@ def test_thread_local_transient_single_state(): body.add_edge(anode, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert not ('A' in sdfg.arrays) @@ -587,6 +588,7 @@ def test_shared_local_transient_multi_state(): body1.add_edge(anode1, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body1.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert 'A' in sdfg.arrays @@ -628,6 +630,7 @@ def test_thread_local_transient_multi_state(): body1.add_edge(anode1, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body1.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert not ('A' in sdfg.arrays) @@ -748,6 +751,7 @@ def test_rotated_loop_to_map(simplify): t = body.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]')) body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]')) + sdfg.apply_transformations_repeated([LoopLifting]) if simplify: sdfg.apply_transformations_repeated(StateFusion) From bbddc88bf48b76f65f1443019bc1eaf092e673e8 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 9 Oct 2024 19:25:00 +0200 Subject: [PATCH 026/108] Finished LoopToMap --- dace/sdfg/graph.py | 2 +- dace/sdfg/propagation.py | 27 +++++----- dace/sdfg/state.py | 51 ++++++++++++++----- .../transformation/interstate/loop_lifting.py | 2 +- dace/transformation/interstate/loop_to_map.py | 12 ----- .../passes/analysis/analysis.py | 2 +- .../passes/dead_dataflow_elimination.py | 2 +- dace/transformation/passes/simplify.py | 4 +- dace/transformation/transformation.py | 8 +-- tests/transformations/loop_to_map_test.py | 38 ++++++-------- 10 files changed, 78 insertions(+), 70 deletions(-) diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index 5ec4bbb029..a154868fd3 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -680,7 +680,7 @@ def add_edge(self, src: NodeT, dst: NodeT, data: EdgeT = None): def remove_node(self, node: NodeT): try: - for edge in itertools.chain(self.in_edges(node), self.out_edges(node)): + for edge in self.all_edges(node): self.remove_edge(edge) del self._nodes[node] self._nx.remove_node(node) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index f62bb6eb58..e502db6166 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -9,7 +9,7 @@ import itertools import warnings from collections import deque -from typing import List, Set +from typing import TYPE_CHECKING, List, Set import sympy from sympy import Symbol, ceiling @@ -22,6 +22,11 @@ from dace.symbolic import issymbolic, pystr_to_symbolic, simplify +if TYPE_CHECKING: + from dace.sdfg import SDFG + from dace.sdfg.state import SDFGState + + @registry.make_registry class MemletPattern(object): """ @@ -561,7 +566,7 @@ def propagate(self, array, expressions, node_range): return subsets.Range(rng) -def _annotate_loop_ranges(sdfg, unannotated_cycle_states): +def _annotate_loop_ranges(sdfg: 'SDFG', unannotated_cycle_states): """ Annotate each valid for loop construct with its loop variable ranges. @@ -682,7 +687,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): return condition_edges -def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: +def propagate_states(sdfg: 'SDFG', concretize_dynamic_unbounded: bool = False) -> None: """ Annotate the states of an SDFG with the number of executions. @@ -948,7 +953,7 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: sdfg.remove_node(temp_exit_state) -def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): +def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState', nsdfg_node: nodes.NestedSDFG): """ Propagate memlets out of a nested sdfg. @@ -980,7 +985,7 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): # the corresponding memlets and use them to calculate the memlet volume and # subset corresponding to the outside memlet attached to that connector. # This is passed out via `border_memlets` and propagated along from there. - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.data_nodes(): for direction in border_memlets: if (node.label not in border_memlets[direction]): @@ -1139,20 +1144,18 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): oedge.data.dynamic = True -def reset_state_annotations(sdfg): +def reset_state_annotations(sdfg: 'SDFG'): """ Resets the state (loop-related) annotations of an SDFG. :note: This operation is shallow (does not go into nested SDFGs). """ - for state in sdfg.nodes(): + for state in sdfg.states(): state.executions = 0 state.dynamic_executions = True state.ranges = {} - state.is_loop_guard = False - state.itervar = None -def propagate_memlets_sdfg(sdfg): +def propagate_memlets_sdfg(sdfg: 'SDFG'): """ Propagates memlets throughout an entire given SDFG. :note: This is an in-place operation on the SDFG. @@ -1160,13 +1163,13 @@ def propagate_memlets_sdfg(sdfg): # Reset previous annotations first reset_state_annotations(sdfg) - for state in sdfg.nodes(): + for state in sdfg.states(): propagate_memlets_state(sdfg, state) propagate_states(sdfg) -def propagate_memlets_state(sdfg, state): +def propagate_memlets_state(sdfg: 'SDFG', state: 'SDFGState'): """ Propagates memlets throughout one SDFG state. :param sdfg: The SDFG in which the state is situated. diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 2036ea7940..abd325da53 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2501,13 +2501,14 @@ def sdfg(self) -> 'SDFG': @make_properties -class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, - ControlFlowBlock): +class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, + ControlFlowBlock, abc.ABC): - def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None): + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, + parent: Optional['AbstractControlFlowRegion'] = None): OrderedDiGraph.__init__(self) ControlGraphView.__init__(self) - ControlFlowBlock.__init__(self, label, sdfg) + ControlFlowBlock.__init__(self, label, sdfg, parent) self._labels: Set[str] = set() self._start_block: Optional[int] = None @@ -2683,9 +2684,13 @@ def add_node(self, self._cached_start_block = None node.parent_graph = self if isinstance(self, dace.SDFG): - node.sdfg = self + sdfg = self else: - node.sdfg = self.sdfg + sdfg = self.sdfg + node.sdfg = sdfg + if isinstance(node, AbstractControlFlowRegion): + for n in node.all_control_flow_blocks(): + n.sdfg = self.sdfg start_block = is_start_block if is_start_state is not None: warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) @@ -2963,6 +2968,13 @@ def start_block(self, block_id): self._cached_start_block = self.node(block_id) +@make_properties +class ControlFlowRegion(AbstractControlFlowRegion): + + def __init__(self, label = '', sdfg = None, parent = None): + super().__init__(label, sdfg, parent) + + @make_properties class LoopRegion(ControlFlowRegion): """ @@ -3244,7 +3256,7 @@ def has_return(self) -> bool: @make_properties -class ConditionalBlock(ControlFlowBlock, ControlGraphView): +class ConditionalBlock(AbstractControlFlowRegion): _branches: List[Tuple[Optional[CodeBlock], ControlFlowRegion]] @@ -3264,7 +3276,7 @@ def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]: def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): self._branches.append([condition, branch]) - branch.parent_graph = self.parent_graph + branch.parent_graph = self branch.sdfg = self.sdfg def remove_branch(self, branch: ControlFlowRegion): @@ -3273,12 +3285,6 @@ def remove_branch(self, branch: ControlFlowRegion): if b is not branch: filtered_branches.append((c, b)) self._branches = filtered_branches - - def nodes(self) -> List['ControlFlowBlock']: - return [node for _, node in self._branches if node is not None] - - def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: - return [] def _used_symbols_internal(self, all_symbols: bool, @@ -3397,6 +3403,23 @@ def inline(self) -> Tuple[bool, Any]: return True, (guard_state, end_state) + # Graph API overrides. + + def nodes(self) -> List['ControlFlowBlock']: + return [node for _, node in self._branches if node is not None] + + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + + def in_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + + def out_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + + def all_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + @make_properties class NamedRegion(ControlFlowRegion): diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 0c64c1548a..b4ffcd5c93 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -60,7 +60,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): if k != itvar: left_over_incr_assignments[k] = incr_edge.data.assignments[k] - if inverted and incr_edge is cond_edge: + if (inverted or self.expr_index == 4) and incr_edge is cond_edge: update_before_condition = False else: update_before_condition = True diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 595bf7265e..fed3118481 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -556,16 +556,4 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): if itervar in sdfg.free_symbols: sdfg.remove_symbol(itervar) - # Reset all nested SDFG parent pointers - if nsdfg is not None: - if isinstance(nsdfg, nodes.NestedSDFG): - nsdfg = nsdfg.sdfg - - for nstate in nsdfg.nodes(): - for nnode in nstate.nodes(): - if isinstance(nnode, nodes.NestedSDFG): - nnode.sdfg.parent_nsdfg_node = nnode - nnode.sdfg.parent = nstate - nnode.sdfg.parent_sdfg = nsdfg - sdfg.reset_cfg_list() diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 45153c23b4..d8c791f4e5 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -254,7 +254,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): fsyms = e.data.free_symbols & anames if fsyms: result[e.src][0].update(fsyms) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index c95aee665b..964d362d99 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -41,7 +41,7 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If dataflow or states changed, new dead code may be exposed - return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.States) + return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.CFG) def depends_on(self) -> Set[Type[ppl.Pass]]: return {ap.StateReachability, ap.AccessSets} diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 81e8e88362..bd28f8d377 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass from typing import Any, Dict, Optional, Set import warnings @@ -16,10 +16,12 @@ from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols from dace.transformation.passes.reference_reduction import ReferenceToView +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising SIMPLIFY_PASSES = [ InlineSDFGs, ScalarToSymbolPromotion, + ControlFlowRaising, FuseStates, OptionalArrayInference, ConstantPropagation, diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 727ec5555b..3b89612026 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This file contains classes that describe data-centric transformations. @@ -20,10 +20,10 @@ import abc import copy -from dace import dtypes, serialize +from dace import serialize from dace.dtypes import ScheduleType from dace.sdfg import SDFG, SDFGState -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st from dace.properties import make_properties, Property, DictProperty, SetProperty from dace.transformation import pass_pipeline as ppl @@ -339,7 +339,7 @@ def _can_be_applied_and_apply( # Check that all keyword arguments are nodes and if interstate or not sample_node = next(iter(where.values())) - if isinstance(sample_node, SDFGState): + if isinstance(sample_node, ControlFlowBlock): graph = sample_node.parent_graph state_id = -1 cfg_id = graph.cfg_id diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index c342dedae0..5f4b5c66f9 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -3,15 +3,14 @@ import copy import os import tempfile -from typing import Tuple import numpy as np import pytest import dace -from dace.sdfg import nodes, propagation +from dace.sdfg import nodes +from dace.sdfg.state import LoopRegion from dace.transformation.interstate import LoopToMap, StateFusion -from dace.transformation.interstate.loop_detection import DetectLoop from dace.transformation.interstate.loop_lifting import LoopLifting @@ -653,34 +652,25 @@ def nested_loops(A: dace.int32[10, 10, 10], l: dace.int32): sdfg = nested_loops.to_sdfg() - def find_loop(sdfg: dace.SDFG, itervar: str) -> Tuple[dace.SDFGState, dace.SDFGState, dace.SDFGState]: - - guard, begin, fexit = None, None, None - for e in sdfg.edges(): - if itervar in e.data.assignments and e.data.assignments[itervar] == '0': - guard = e.dst - elif e.data.condition.as_string in (f'({itervar} >= 10)', f'(not ({itervar} < 10))'): - fexit = e.dst - assert all(s is not None for s in (guard, fexit)) - - begin = next((e for e in sdfg.out_edges(guard) if e.dst != fexit)).dst - - return guard, begin, fexit + def find_loop(sdfg: dace.SDFG, itervar: str) -> LoopRegion: + for cfg in sdfg.all_control_flow_regions(): + if isinstance(cfg, LoopRegion) and cfg.loop_variable == itervar: + return cfg sdfg0 = copy.deepcopy(sdfg) - i_guard, i_begin, i_exit = find_loop(sdfg0, 'i') - LoopToMap.apply_to(sdfg0, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit) + i_loop = find_loop(sdfg0, 'i') + LoopToMap.apply_to(sdfg0, loop=i_loop) nsdfg = next((sd for sd in sdfg0.all_sdfgs_recursive() if sd.parent is not None)) - j_guard, j_begin, j_exit = find_loop(nsdfg, 'j') - LoopToMap.apply_to(nsdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit) + j_loop = find_loop(nsdfg, 'j') + LoopToMap.apply_to(nsdfg, loop=j_loop) val = np.arange(1000, dtype=np.int32).reshape(10, 10, 10).copy() sdfg(A=val, l=5) assert np.allclose(ref, val) - j_guard, j_begin, j_exit = find_loop(sdfg, 'j') - LoopToMap.apply_to(sdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit) + j_loop = find_loop(sdfg, 'j') + LoopToMap.apply_to(sdfg, loop=j_loop) # NOTE: The following fails to apply because of subset A[0:i+1], which is overapproximated. # i_guard, i_begin, i_exit = find_loop(sdfg, 'i') # LoopToMap.apply_to(sdfg, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit) @@ -720,7 +710,7 @@ def internal_write(inp0: dace.int32[10], inp1: dace.int32[10], out: dace.int32[1 val = np.empty((10, ), dtype=np.int32) internal_write.f(inp0, inp1, ref) - internal_write(inp0, inp1, val) + sdfg(inp0, inp1, val) assert np.array_equal(val, ref) @@ -782,6 +772,8 @@ def test_self_loop_to_map(): body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]')) body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]')) + sdfg.apply_transformations_repeated([LoopLifting]) + assert sdfg.apply_transformations_repeated(LoopToMap) == 1 a = np.random.rand(20) From 5ef1f461cd05e29af06f7c5e3d9839dc52b867cf Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 9 Oct 2024 21:36:26 +0200 Subject: [PATCH 027/108] Adapt loop unrolling --- dace/transformation/interstate/loop_unroll.py | 173 +++++++++--------- .../passes/analysis/loop_analysis.py | 8 +- .../transformation/subgraph/stencil_tiling.py | 7 +- .../transformations/loop_manipulation_test.py | 4 +- 4 files changed, 90 insertions(+), 102 deletions(-) diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index 663745c0d6..ff35b9f4a5 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -1,135 +1,128 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop unroll transformation """ import copy -from typing import List +from typing import List, Optional from dace import sdfg as sd, symbolic -from dace.properties import Property, make_properties -from dace.sdfg import graph as gr +from dace.properties import CodeBlock, Property, make_properties from dace.sdfg import utils as sdutil -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.frontend.python.astutils import ASTFindReplace -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from dace.transformation import transformation as xf +from dace.transformation.passes.analysis import loop_analysis @make_properties @xf.experimental_cfg_block_compatible -class LoopUnroll(DetectLoop, xf.MultiStateTransformation): +class LoopUnroll(xf.MultiStateTransformation): """ Unrolls a state machine for-loop into multiple states """ + loop = xf.PatternNode(LoopRegion) + count = Property( dtype=int, default=0, - desc='Number of iterations to unroll, or zero for all ' - 'iterations (loop must be constant-sized for 0)', + desc='Number of iterations to unroll, or zero for all iterations (loop must be constant-sized for 0)', ) - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + inline_iterations = Property(dtype=bool, default=True, + desc='Whether or not to inline individual iteration\'s CFGs after unrolling') - found = self.loop_information() + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # If loop cannot be detected, fail - if not found: + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # If loop information cannot be determined, fail. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) + itervar = self.loop.loop_variable + if start is None or end is None or step is None or itervar is None: return False - _, rng, _ = found # If loop stride is not specialized or constant-sized, fail - if symbolic.issymbolic(rng[2], sdfg.constants): + if symbolic.issymbolic(step, sdfg.constants): return False # If loop range diff is not constant-sized, fail - if symbolic.issymbolic(rng[1] - rng[0], sdfg.constants): + if symbolic.issymbolic(end - start, sdfg.constants): return False return True def apply(self, graph: ControlFlowRegion, sdfg): # Obtain loop information - begin: sd.SDFGState = self.loop_begin - after_state: sd.SDFGState = self.exit_state - - # Obtain iteration variable, range, and stride, together with the last - # state(s) before the loop and the last loop state. - itervar, rng, loop_struct = self.loop_information() - - # Loop must be fully unrollable for now. - if self.count != 0: - raise NotImplementedError # TODO(later) - - # Get loop states - loop_states = self.loop_body() - first_id = loop_states.index(begin) - last_state = loop_struct[1] - last_id = loop_states.index(last_state) - loop_subgraph = gr.SubgraphView(graph, loop_states) + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + stride = loop_analysis.get_loop_stride(self.loop) try: - start, end, stride = (r for r in rng) stride = symbolic.evaluate(stride, sdfg.constants) loop_diff = int(symbolic.evaluate(end - start + 1, sdfg.constants)) - is_symbolic = any([symbolic.issymbolic(r) for r in rng[:2]]) + is_symbolic = any([symbolic.issymbolic(r) for r in (start, end)]) except TypeError: raise TypeError('Loop difference and strides cannot be symbolic.') - # Create states for loop subgraph - unrolled_states = [] + # Create states for loop subgraph + unrolled_iterations: List[ControlFlowRegion] = [] for i in range(0, loop_diff, stride): + if self.count != 0 and i >= self.count: + break + + # Instantiate loop contents as a new control flow region with iterate value. current_index = start + i - # Instantiate loop states with iterate value - new_states = self.instantiate_loop(sdfg, loop_states, loop_subgraph, itervar, current_index, - str(i) if is_symbolic else None) + iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, + str(i) if is_symbolic else None) # Connect iterations with unconditional edges - if len(unrolled_states) > 0: - graph.add_edge(unrolled_states[-1][1], new_states[first_id], sd.InterstateEdge()) - - unrolled_states.append((new_states[first_id], new_states[last_id])) - - # Get any assignments that might be on the edge to the after state - after_assignments = self.loop_exit_edge().data.assignments - - # Connect new states to before and after states without conditions - if unrolled_states: - before_states = loop_struct[0] - for before_state in before_states: - graph.add_edge(before_state, unrolled_states[0][0], sd.InterstateEdge()) - graph.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments)) - - # Remove old states from SDFG - guard_or_latch = self.loop_meta_states() - graph.remove_nodes_from(guard_or_latch + loop_states) - - def instantiate_loop( - self, - sdfg: sd.SDFG, - loop_states: List[sd.SDFGState], - loop_subgraph: gr.SubgraphView, - itervar: str, - value: symbolic.SymbolicType, - state_suffix=None, - ): - # Using to/from JSON copies faster than deepcopy (which will also - # copy the parent SDFG) - new_states = [sd.SDFGState.from_json(s.to_json(), context={'sdfg': sdfg}) for s in loop_states] - - # Replace iterate with value in each state - for state in new_states: - state.label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value)) - state.replace(itervar, value) - - graph = loop_states[0].parent_graph - # Add subgraph to original SDFG - for edge in loop_subgraph.edges(): - src = new_states[loop_states.index(edge.src)] - dst = new_states[loop_states.index(edge.dst)] + if len(unrolled_iterations) > 0: + graph.add_edge(unrolled_iterations[-1], iteration_region, sd.InterstateEdge()) + unrolled_iterations.append(iteration_region) + if self.count != 0: + # Not all iterations of the loop were unrolled. Connect the unrolled iterations accordingly and adjust the + # remaining loop bounds. + if unrolled_iterations: + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, unrolled_iterations[0], ie.data) + graph.add_edge(unrolled_iterations[-1], self.loop, sd.InterstateEdge()) + + new_start = symbolic.evaluate(start + (self.count * stride), sdfg.constants) + self.loop.init_statement = CodeBlock(f'{self.loop.loop_variable} = {new_start}') + else: + # Everything was unrolled. + # Connect the unrolled iterations to the rest of the graph. + if unrolled_iterations: + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, unrolled_iterations[0], ie.data) + for oe in graph.out_edges(self.loop): + graph.add_edge(unrolled_iterations[-1], oe.dst, oe.data) + + # Remove old loop. + graph.remove_node(self.loop) + + if self.inline_iterations: + for it in unrolled_iterations: + it.inline() + + def instantiate_loop_iteration(self, graph: ControlFlowRegion, loop: LoopRegion, value: symbolic.SymbolicType, + label_suffix: Optional[str] = None) -> ControlFlowRegion: + it_label = loop.label + '_' + loop.loop_variable + (label_suffix if label_suffix is not None else str(value)) + iteration_region = ControlFlowRegion(it_label, graph.sdfg, graph) + graph.add_node(iteration_region) + block_map = {} + for block in loop.nodes(): + # Using to/from JSON copies faster than deepcopy. + new_block = sd.SDFGState.from_json(block.to_json(), context={'sdfg': graph.sdfg}) + block_map[block] = new_block + new_block.replace(loop.loop_variable, value) + iteration_region.add_node(new_block, is_start_block=(block is loop.start_block)) + for edge in loop.edges(): + src = block_map[edge.src] + dst = block_map[edge.dst] # Replace conditions in subgraph edges - data: sd.InterstateEdge = copy.deepcopy(edge.data) + data = copy.deepcopy(edge.data) if not data.is_unconditional(): - ASTFindReplace({itervar: str(value)}).visit(data.condition) - - graph.add_edge(src, dst, data) + ASTFindReplace({loop.loop_variable: str(value)}).visit(data.condition) + iteration_region.add_edge(src, dst, data) - return new_states + return iteration_region diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 9dbfb62e9f..806df4d635 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -1,16 +1,12 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast -from collections import defaultdict -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, Optional import sympy -from dace import SDFG, properties, symbolic, transformation -from dace.memlet import Memlet +from dace import symbolic from dace.sdfg.state import LoopRegion -from dace.subsets import Range, SubsetUnion -from dace.transformation import pass_pipeline as ppl class FindAssignment(ast.NodeVisitor): diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 1ba86252c4..018bc723f3 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -2,10 +2,8 @@ """ This module contains classes and functions that implement the orthogonal stencil tiling transformation. """ -import math - import dace -from dace import dtypes, registry, symbolic +from dace import dtypes, symbolic from dace.properties import make_properties, Property, ShapeProperty from dace.sdfg import nodes from dace.transformation import transformation @@ -15,7 +13,6 @@ from dace.transformation.dataflow.map_expansion import MapExpansion from dace.transformation.dataflow.map_collapse import MapCollapse from dace.transformation.dataflow.strip_mining import StripMining -from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.interstate.loop_detection import DetectLoop from dace.transformation.subgraph import SubgraphFusion @@ -573,6 +570,8 @@ def apply(self, sdfg): nsdfg = trafo_for_loop.nsdfg # LoopUnroll + # Prevent circular import + from dace.transformation.interstate.loop_unroll import LoopUnroll guard = trafo_for_loop.guard end = trafo_for_loop.after_state diff --git a/tests/transformations/loop_manipulation_test.py b/tests/transformations/loop_manipulation_test.py index dbeed91464..292a451510 100644 --- a/tests/transformations/loop_manipulation_test.py +++ b/tests/transformations/loop_manipulation_test.py @@ -27,9 +27,9 @@ def regression(A, B): def test_unroll(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 2 sdfg.apply_transformations(LoopUnroll) - assert len(sdfg.nodes()) == (5 + 2) + assert len(sdfg.nodes()) == 1 + 5 * 2 sdfg.simplify() assert len(sdfg.nodes()) == 1 A = np.random.rand(20) From 2d26e0f9c5aa90a946fe1997dd1f41f1fc13c68e Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 10 Oct 2024 18:30:00 +0200 Subject: [PATCH 028/108] Adapt loop peeling and unrolling --- .../transformation/interstate/loop_peeling.py | 153 ++++++------------ dace/transformation/interstate/loop_unroll.py | 40 ++--- .../transformations/loop_manipulation_test.py | 8 +- 3 files changed, 69 insertions(+), 132 deletions(-) diff --git a/dace/transformation/interstate/loop_peeling.py b/dace/transformation/interstate/loop_peeling.py index c2e50cd37a..710f6f5d97 100644 --- a/dace/transformation/interstate/loop_peeling.py +++ b/dace/transformation/interstate/loop_peeling.py @@ -1,17 +1,16 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" Loop unroll transformation """ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Loop peeling transformation """ import sympy as sp -from typing import Optional +from typing import List, Optional from dace import sdfg as sd +from dace import symbolic from dace.sdfg.state import ControlFlowRegion from dace.properties import Property, make_properties, CodeBlock -from dace.sdfg import graph as gr -from dace.sdfg import utils as sdutil from dace.symbolic import pystr_to_symbolic -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from dace.transformation.interstate.loop_unroll import LoopUnroll +from dace.transformation.passes.analysis import loop_analysis from dace.transformation.transformation import experimental_cfg_block_compatible @@ -19,29 +18,18 @@ @experimental_cfg_block_compatible class LoopPeeling(LoopUnroll): """ - Splits the first `count` iterations of a state machine for-loop into - multiple, separate states. + Splits the first `count` iterations of loop into multiple, separate control flow regions (one per iteration). """ begin = Property( dtype=bool, default=True, - desc='If True, peels loop from beginning (first `count` ' - 'iterations), otherwise peels last `count` iterations.', + desc='If True, peels loop from beginning (first `count` iterations), otherwise peels last `count` iterations.', ) def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False - - guard = self.loop_guard - begin = self.loop_begin - - # If loop cannot be detected, fail - found = find_for_loop(sdfg, guard, begin) - if found is None: - return False - return True def _modify_cond(self, condition, var, step): @@ -77,90 +65,55 @@ def _modify_cond(self, condition, var, step): return res def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): - #################################################################### # Obtain loop information - begin: sd.SDFGState = self.loop_begin - after_state: sd.SDFGState = self.exit_state - - # Obtain iteration variable, range, and stride - condition_edge = self.loop_condition_edge() - not_condition_edge = self.loop_exit_edge() - itervar, rng, loop_struct = self.loop_information() - - # Get loop states - loop_states = self.loop_body() - first_id = loop_states.index(begin) - last_state = loop_struct[1] - last_id = loop_states.index(last_state) - loop_subgraph = gr.SubgraphView(graph, loop_states) - - #################################################################### - # Transform + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + stride = loop_analysis.get_loop_stride(self.loop) + is_symbolic = any([symbolic.issymbolic(r) for r in (start, end)]) if self.begin: - # If begin, change initialization assignment and prepend states before - # guard - init_edges = [] - before_states = loop_struct[0] - for before_state in before_states: - init_edge = self.loop_init_edge() - init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2]) - init_edges.append(init_edge) - append_states = before_states - - # Add `count` states, each with instantiated iteration variable + # Create states for loop subgraph + peeled_iterations: List[ControlFlowRegion] = [] for i in range(self.count): - # Instantiate loop states with iterate value - state_name: str = 'start_' + itervar + str(i * rng[2]) - state_name = state_name.replace('-', 'm').replace('+', 'p').replace('*', 'M').replace('/', 'D') - new_states = self.instantiate_loop( - sdfg, - loop_states, - loop_subgraph, - itervar, - rng[0] + i * rng[2], - state_name, - ) - - # Connect states to before the loop with unconditional edges - for append_state in append_states: - graph.add_edge(append_state, new_states[first_id], sd.InterstateEdge()) - append_states = [new_states[last_id]] - - # Reconnect edge to guard state from last peeled iteration - for append_state in append_states: - if append_state not in before_states: - for init_edge in init_edges: - graph.remove_edge(init_edge) - graph.add_edge(append_state, init_edge.dst, init_edges[0].data) + # Instantiate loop contents as a new control flow region with iterate value. + current_index = start + (i * stride) + iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, + str(i) if is_symbolic else None) + + # Connect iterations with unconditional edges + if len(peeled_iterations) > 0: + graph.add_edge(peeled_iterations[-1], iteration_region, sd.InterstateEdge()) + peeled_iterations.append(iteration_region) + + # Connect the peeled iterations to the remainder of the loop and adjust the remaining iteration bounds. + if peeled_iterations: + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, peeled_iterations[0], ie.data) + graph.remove_edge(ie) + graph.add_edge(peeled_iterations[-1], self.loop, sd.InterstateEdge()) + + new_start = symbolic.evaluate(start + (self.count * stride), sdfg.constants) + self.loop.init_statement = CodeBlock(f'{self.loop.loop_variable} = {new_start}') else: - # If begin, change initialization assignment and prepend states before - # guard - itervar_sym = pystr_to_symbolic(itervar) - condition_edge.data.condition = CodeBlock(self._modify_cond(condition_edge.data.condition, itervar, rng[2])) - not_condition_edge.data.condition = CodeBlock( - self._modify_cond(not_condition_edge.data.condition, itervar, rng[2])) - prepend_state = after_state - - # Add `count` states, each with instantiated iteration variable + # Create states for loop subgraph + peeled_iterations: List[ControlFlowRegion] = [] for i in reversed(range(self.count)): - # Instantiate loop states with iterate value - state_name: str = 'end_' + itervar + str(-i * rng[2]) - state_name = state_name.replace('-', 'm').replace('+', 'p').replace('*', 'M').replace('/', 'D') - new_states = self.instantiate_loop( - sdfg, - loop_states, - loop_subgraph, - itervar, - itervar_sym + i * rng[2], - state_name, - ) - - # Connect states to before the loop with unconditional edges - graph.add_edge(new_states[last_id], prepend_state, sd.InterstateEdge()) - prepend_state = new_states[first_id] - - # Reconnect edge to guard state from last peeled iteration - if prepend_state != after_state: - graph.remove_edge(not_condition_edge) - graph.add_edge(not_condition_edge.src, prepend_state, not_condition_edge.data) + # Instantiate loop contents as a new control flow region with iterate value. + current_index = pystr_to_symbolic(self.loop.loop_variable) + (i * stride) + iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, + str(i) if is_symbolic else None) + + # Connect iterations with unconditional edges + if len(peeled_iterations) > 0: + graph.add_edge(iteration_region, peeled_iterations[-1], sd.InterstateEdge()) + peeled_iterations.append(iteration_region) + + # Connect the peeled iterations to the remainder of the loop and adjust the remaining iteration bounds. + if peeled_iterations: + for oe in graph.out_edges(self.loop): + graph.add_edge(peeled_iterations[0], oe.dst, oe.data) + graph.remove_edge(oe) + graph.add_edge(self.loop, peeled_iterations[-1], sd.InterstateEdge()) + + new_cond = CodeBlock(self._modify_cond(self.loop.loop_condition, self.loop.loop_variable, stride)) + self.loop.loop_condition = new_cond diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index ff35b9f4a5..66dbca1a83 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -5,7 +5,7 @@ from typing import List, Optional from dace import sdfg as sd, symbolic -from dace.properties import CodeBlock, Property, make_properties +from dace.properties import Property, make_properties from dace.sdfg import utils as sdutil from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.frontend.python.astutils import ASTFindReplace @@ -15,7 +15,7 @@ @make_properties @xf.experimental_cfg_block_compatible class LoopUnroll(xf.MultiStateTransformation): - """ Unrolls a state machine for-loop into multiple states """ + """ Unrolls a for-loop into multiple individual control flow regions """ loop = xf.PatternNode(LoopRegion) @@ -50,6 +50,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True def apply(self, graph: ControlFlowRegion, sdfg): + # Loop must be fully unrollable for now. + if self.count != 0: + raise NotImplementedError # TODO(later) + # Obtain loop information start = loop_analysis.get_init_assignment(self.loop) end = loop_analysis.get_loop_end(self.loop) @@ -65,9 +69,6 @@ def apply(self, graph: ControlFlowRegion, sdfg): # Create states for loop subgraph unrolled_iterations: List[ControlFlowRegion] = [] for i in range(0, loop_diff, stride): - if self.count != 0 and i >= self.count: - break - # Instantiate loop contents as a new control flow region with iterate value. current_index = start + i iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, @@ -78,27 +79,14 @@ def apply(self, graph: ControlFlowRegion, sdfg): graph.add_edge(unrolled_iterations[-1], iteration_region, sd.InterstateEdge()) unrolled_iterations.append(iteration_region) - if self.count != 0: - # Not all iterations of the loop were unrolled. Connect the unrolled iterations accordingly and adjust the - # remaining loop bounds. - if unrolled_iterations: - for ie in graph.in_edges(self.loop): - graph.add_edge(ie.src, unrolled_iterations[0], ie.data) - graph.add_edge(unrolled_iterations[-1], self.loop, sd.InterstateEdge()) - - new_start = symbolic.evaluate(start + (self.count * stride), sdfg.constants) - self.loop.init_statement = CodeBlock(f'{self.loop.loop_variable} = {new_start}') - else: - # Everything was unrolled. - # Connect the unrolled iterations to the rest of the graph. - if unrolled_iterations: - for ie in graph.in_edges(self.loop): - graph.add_edge(ie.src, unrolled_iterations[0], ie.data) - for oe in graph.out_edges(self.loop): - graph.add_edge(unrolled_iterations[-1], oe.dst, oe.data) - - # Remove old loop. - graph.remove_node(self.loop) + if unrolled_iterations: + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, unrolled_iterations[0], ie.data) + for oe in graph.out_edges(self.loop): + graph.add_edge(unrolled_iterations[-1], oe.dst, oe.data) + + # Remove old loop. + graph.remove_node(self.loop) if self.inline_iterations: for it in unrolled_iterations: diff --git a/tests/transformations/loop_manipulation_test.py b/tests/transformations/loop_manipulation_test.py index 292a451510..9a3abc0239 100644 --- a/tests/transformations/loop_manipulation_test.py +++ b/tests/transformations/loop_manipulation_test.py @@ -47,10 +47,8 @@ def test_unroll(): def test_peeling_start(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 2 sdfg.apply_transformations(LoopPeeling, dict(count=2)) - assert len(sdfg.nodes()) == 6 - sdfg.simplify() assert len(sdfg.nodes()) == 4 A = np.random.rand(20) B = np.random.rand(20) @@ -67,10 +65,8 @@ def test_peeling_start(): def test_peeling_end(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 2 sdfg.apply_transformations(LoopPeeling, dict(count=2, begin=False)) - assert len(sdfg.nodes()) == 6 - sdfg.simplify() assert len(sdfg.nodes()) == 4 A = np.random.rand(20) B = np.random.rand(20) From 145c0ea5bfb9647d95625b3612f5da65ac312889 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 11 Oct 2024 08:37:57 +0200 Subject: [PATCH 029/108] Added tests for the `_read_and_write_sets()`. --- tests/sdfg/state_test.py | 90 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 7ba43ac4c0..1fac1d56f4 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace import subsets as sbs from dace.transformation.helpers import find_sdfg_control_flow @@ -43,6 +44,7 @@ def test_read_write_set_y_formation(): assert 'B' not in state.read_and_write_sets()[0] + def test_deepcopy_state(): N = dace.symbol('N') @@ -58,6 +60,92 @@ def double_loop(arr: dace.float32[N]): sdfg.validate() +def test_read_and_write_set_filter(): + sdfg = dace.SDFG('graph') + state = sdfg.add_state('state') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_scalar('B', dace.float64) + sdfg.add_array('C', [2, 2], dace.float64) + A, B, C = (state.add_access(name) for name in ('A', 'B', 'C')) + + state.add_nedge( + A, + B, + dace.Memlet("B[0] -> 0, 0"), + ) + state.add_nedge( + B, + C, + # If the Memlet would be `B[0] -> 1, 1` it would then be filtered out. + # This is an intentional behaviour for compatibility. + dace.Memlet("C[1, 1] -> 0"), + ) + state.add_nedge( + B, + C, + dace.Memlet("B[0] -> 0, 0"), + ) + sdfg.validate() + + expected_reads = { + "A": [sbs.Range.from_string("0, 0")], + # See comment in `state._read_and_write_sets()` why "B" is here + # it should actually not, but it is a bug. + "B": [sbs.Range.from_string("0")], + } + expected_writes = { + # However, this should always be here. + "B": [sbs.Range.from_string("0")], + "C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + +def test_read_and_write_set_selection(): + sdfg = dace.SDFG('graph') + state = sdfg.add_state('state') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_scalar('B', dace.float64) + A, B = (state.add_access(name) for name in ('A', 'B')) + + state.add_nedge( + A, + B, + dace.Memlet("A[0, 0]"), + ) + sdfg.validate() + + expected_reads = { + "A": [sbs.Range.from_string("0, 0")], + } + expected_writes = { + "B": [sbs.Range.from_string("0")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + def test_add_mapped_tasklet(): sdfg = dace.SDFG("test_add_mapped_tasklet") state = sdfg.add_state(is_start_block=True) @@ -82,6 +170,8 @@ def test_add_mapped_tasklet(): if __name__ == '__main__': + test_read_and_write_set_selection() + test_read_and_write_set_filter() test_read_write_set() test_read_write_set_y_formation() test_deepcopy_state() From 38e748bfb2576103c0cf490fcb3f77a9c20361d9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 11 Oct 2024 08:49:42 +0200 Subject: [PATCH 030/108] Added the fix from my MapFusion PR. --- dace/sdfg/state.py | 111 +++++++++++++++++++++++++++++++-------------- 1 file changed, 78 insertions(+), 33 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 8d443e6beb..a9c99da7f9 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -745,51 +745,96 @@ def update_if_not_none(dic, update): return defined_syms + def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, List[Subset]]]: """ Determines what data is read and written in this subgraph, returning dictionaries from data containers to all subsets that are read/written. """ + from dace.sdfg import utils # Avoid cyclic import + + # Ensures that the `{src,dst}_subset` are properly set. + # TODO: find where the problems are + for edge in self.edges(): + edge.data.try_initialize(self.sdfg, self, edge) + read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) - from dace.sdfg import utils # Avoid cyclic import - subgraphs = utils.concurrent_subgraphs(self) - for sg in subgraphs: - rs = collections.defaultdict(list) - ws = collections.defaultdict(list) + + for subgraph in utils.concurrent_subgraphs(self): + subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph. + subgraph_write_set = collections.defaultdict(list) # Traverse in topological order, so data that is written before it # is read is not counted in the read set - for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): - if isinstance(n, nd.AccessNode): - in_edges = sg.in_edges(n) - out_edges = sg.out_edges(n) - # Filter out memlets which go out but the same data is written to the AccessNode by another memlet - for out_edge in list(out_edges): - for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data - and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): - out_edges.remove(out_edge) - break - - for e in in_edges: - # skip empty memlets - if e.data.is_empty(): - continue - # Store all subsets that have been written - ws[n.data].append(e.data.subset) - for e in out_edges: - # skip empty memlets - if e.data.is_empty(): + # TODO: This only works if every data descriptor is only once in a path. + for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()): + if not isinstance(n, nd.AccessNode): + # Read and writes can only be done through access nodes, + # so ignore every other node. + continue + + # Get a list of all incoming (writes) and outgoing (reads) edges of the + # access node, ignore all empty memlets as they do not carry any data. + in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()] + out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()] + + # Extract the subsets that describes where we read and write the data + # and store them for the later filtering. + # NOTE: In certain cases the corresponding subset might be None, in this case + # we assume that the whole array is written, which is the default behaviour. + ac_desc = n.desc(self.sdfg) + ac_size = ac_desc.total_size + in_subsets = dict() + for in_edge in in_edges: + # Ensure that if the destination subset is not given, our assumption, that the + # whole array is written to, is valid, by testing if the memlet transfers the + # whole array. + assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size) + in_subsets[in_edge] = ( + sbs.Range.from_array(ac_desc) + if in_edge.data.dst_subset is None + else in_edge.data.dst_subset + ) + out_subsets = dict() + for out_edge in out_edges: + assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size) + out_subsets[out_edge] = ( + sbs.Range.from_array(ac_desc) + if out_edge.data.src_subset is None + else out_edge.data.src_subset + ) + + # If a memlet reads a particular region of data from the access node and there + # exists a memlet at the same access node that writes to the same region, then + # this read is ignored, and not included in the final read set, but only + # accounted fro in the write set. See also note below. + # TODO: Handle the case when multiple disjoint writes are needed to cover the read. + for out_edge in list(out_edges): + for in_edge in in_edges: + if out_edge.data.data != in_edge.data.data: + # NOTE: This check does not make any sense, and is in my (@philip-paul-mueller) + # view wrong. As it will filter out some accesses but not all, which one solely + # depends on how the memelts were created, i.e. to which container their `data` + # attribute is associated to. See also [issue #1643](https://github.com/spcl/dace/issues/1643). continue - rs[n.data].append(e.data.subset) - # Union all subgraphs, so an array that was excluded from the read - # set because it was written first is still included if it is read - # in another subgraph - for data, accesses in rs.items(): + if in_subsets[in_edge].covers(out_subsets[out_edge]): + out_edges.remove(out_edge) + break + + # Update the read and write sets of the subgraph. + if in_edges: + subgraph_write_set[n.data].extend(in_subsets.values()) + if out_edges: + subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) + + # Add the subgraph's read and write set to the final ones. + for data, accesses in subgraph_read_set.items(): read_set[data] += accesses - for data, accesses in ws.items(): + for data, accesses in subgraph_write_set.items(): write_set[data] += accesses - return read_set, write_set + + return copy.deepcopy((read_set, write_set)) + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ From 3748c03cd863c3f5241ab5e0e03164d3c6c61319 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 11 Oct 2024 08:55:51 +0200 Subject: [PATCH 031/108] Now made `read_and_write_sets()` fully adhere to their own definition. --- dace/sdfg/state.py | 6 ------ tests/sdfg/state_test.py | 6 ------ 2 files changed, 12 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a9c99da7f9..ba853d088d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -811,12 +811,6 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # TODO: Handle the case when multiple disjoint writes are needed to cover the read. for out_edge in list(out_edges): for in_edge in in_edges: - if out_edge.data.data != in_edge.data.data: - # NOTE: This check does not make any sense, and is in my (@philip-paul-mueller) - # view wrong. As it will filter out some accesses but not all, which one solely - # depends on how the memelts were created, i.e. to which container their `data` - # attribute is associated to. See also [issue #1643](https://github.com/spcl/dace/issues/1643). - continue if in_subsets[in_edge].covers(out_subsets[out_edge]): out_edges.remove(out_edge) break diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 1fac1d56f4..7d74ae6bcc 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -76,8 +76,6 @@ def test_read_and_write_set_filter(): state.add_nedge( B, C, - # If the Memlet would be `B[0] -> 1, 1` it would then be filtered out. - # This is an intentional behaviour for compatibility. dace.Memlet("C[1, 1] -> 0"), ) state.add_nedge( @@ -89,12 +87,8 @@ def test_read_and_write_set_filter(): expected_reads = { "A": [sbs.Range.from_string("0, 0")], - # See comment in `state._read_and_write_sets()` why "B" is here - # it should actually not, but it is a bug. - "B": [sbs.Range.from_string("0")], } expected_writes = { - # However, this should always be here. "B": [sbs.Range.from_string("0")], "C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")], } From 6d83976df49df126704396d0207df936cf3af2e4 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 11 Oct 2024 09:24:25 +0200 Subject: [PATCH 032/108] Update refine nested access --- dace/sdfg/state.py | 38 +++--- .../interstate/move_loop_into_map.py | 111 +++++++----------- .../transformation/interstate/sdfg_nesting.py | 13 +- dace/transformation/pass_pipeline.py | 49 ++++++++ .../refine_nested_access_test.py | 2 +- 5 files changed, 124 insertions(+), 89 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index abd325da53..1f80e1c2f8 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -23,7 +23,7 @@ from dace.properties import (CodeBlock, DebugInfoProperty, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import (MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge, +from dace.sdfg.graph import (MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge, generate_element_id) from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state @@ -2522,7 +2522,7 @@ def root_sdfg(self) -> 'SDFG': raise RuntimeError('Root CFG is not of type SDFG') return self.cfg_list[0] - def reset_cfg_list(self) -> List['ControlFlowRegion']: + def reset_cfg_list(self) -> List['AbstractControlFlowRegion']: """ Reset the CFG list when changes have been made to the SDFG's CFG tree. This collects all control flow graphs recursively and propagates the collection to all CFGs as the new CFG list. @@ -2766,7 +2766,7 @@ def add_state_after(self, ################################################################### # Traversal methods - def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']: + def all_control_flow_regions(self, recursive=False) -> Iterator['AbstractControlFlowRegion']: """ Iterate over this and all nested control flow regions. """ yield self for block in self.nodes(): @@ -2774,11 +2774,8 @@ def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegi for node in block.nodes(): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_control_flow_regions(recursive=recursive) - elif isinstance(block, ControlFlowRegion): + elif isinstance(block, AbstractControlFlowRegion): yield from block.all_control_flow_regions(recursive=recursive) - elif isinstance(block, ConditionalBlock): - for _, branch in block.branches: - yield from branch.all_control_flow_regions(recursive=recursive) def all_sdfgs_recursive(self) -> Iterator['SDFG']: """ Iterate over this and all nested SDFGs. """ @@ -2791,11 +2788,8 @@ def all_states(self) -> Iterator[SDFGState]: for block in self.nodes(): if isinstance(block, SDFGState): yield block - elif isinstance(block, ControlFlowRegion): + elif isinstance(block, AbstractControlFlowRegion): yield from block.all_states() - elif isinstance(block, ConditionalBlock): - for _, region in block.branches: - yield from region.all_states() def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: """ Iterate over all control flow blocks in this control flow graph. """ @@ -3341,9 +3335,9 @@ def from_json(cls, json_obj, context=None): for condition, region in json_obj['branches']: if condition is not None: - ret._branches.append((CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context))) + ret.add_branch(CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context)) else: - ret._branches.append((None, ControlFlowRegion.from_json(region, context))) + ret.add_branch(None, ControlFlowRegion.from_json(region, context)) return ret def inline(self) -> Tuple[bool, Any]: @@ -3403,10 +3397,26 @@ def inline(self) -> Tuple[bool, Any]: return True, (guard_state, end_state) + # Abstract control flow region overrides + + @property + def start_block(self): + return None + + @start_block.setter + def start_block(self, _): + pass + # Graph API overrides. + def node_id(self, node: 'ControlFlowBlock') -> int: + try: + return next(i for i, (_, b) in enumerate(self._branches) if b is node) + except StopIteration: + raise NodeNotFoundError(node) + def nodes(self) -> List['ControlFlowBlock']: - return [node for _, node in self._branches if node is not None] + return [node for _, node in self._branches] def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: return [] diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index 29a9906fe0..5017ba3bea 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -1,18 +1,19 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Moves a loop around a map into the map """ import copy +from dace.sdfg.state import ControlFlowRegion, LoopRegion, SDFGState import dace.transformation.helpers as helpers import networkx as nx from dace.sdfg.scope import ScopeTree -from dace import data as dt, Memlet, nodes, sdfg as sd, subsets as sbs, symbolic, symbol -from dace.properties import CodeBlock -from dace.sdfg import nodes, propagation +from dace import Memlet, nodes, sdfg as sd, subsets as sbs, symbolic, symbol +from dace.sdfg import nodes, propagation, utils as sdutil from dace.transformation import transformation -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from sympy import diff from typing import List, Set, Tuple +from dace.transformation.passes.analysis import loop_analysis + def fold(memlet_subset_ranges, itervar, lower, upper): return [(r[0].replace(symbol(itervar), lower), r[1].replace(symbol(itervar), upper), r[2]) @@ -23,32 +24,34 @@ def offset(memlet_subset_ranges, value): return (memlet_subset_ranges[0] + value, memlet_subset_ranges[1] + value, memlet_subset_ranges[2]) -@transformation.single_level_sdfg_only -class MoveLoopIntoMap(DetectLoop, transformation.MultiStateTransformation): +@transformation.experimental_cfg_block_compatible +class MoveLoopIntoMap(transformation.MultiStateTransformation): """ Moves a loop around a map into the map """ - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + loop = transformation.PatternNode(LoopRegion) - # Obtain loop information - body: sd.SDFGState = self.loop_begin + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # Obtain iteration variable, range, and stride - loop_info = self.loop_information() - if not loop_info: + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # If loop information cannot be determined, fail. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) + itervar = self.loop.loop_variable + if start is None or end is None or step is None or itervar is None: return False - itervar, (start, end, step), (_, body_end) = loop_info if step not in [-1, 1]: return False # Body must contain a single state - if body != body_end: + if len(self.loop.nodes()) != 1 or not isinstance(self.loop.nodes()[0], SDFGState): return False + body: SDFGState = self.loop.nodes()[0] # Body must have only a single connected component # NOTE: This is a strict check that can be potentially relaxed. @@ -153,14 +156,9 @@ def test_subset_dependency(subset: sbs.Subset, mparams: Set[int]) -> Tuple[bool, return True - def apply(self, _, sdfg: sd.SDFG): - # Obtain loop information - body: sd.SDFGState = self.loop_begin - - # Obtain iteration variable, range, and stride - itervar, (start, end, step), _ = self.loop_information() - - forward_loop = step > 0 + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): + body: sd.SDFGState = self.loop.nodes()[0] + itervar = self.loop.loop_variable for node in body.nodes(): if isinstance(node, nodes.MapEntry): @@ -171,50 +169,27 @@ def apply(self, _, sdfg: sd.SDFG): # nest map's content in sdfg map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False) nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True) + nested_state: SDFGState = nsdfg.sdfg.nodes()[0] # replicate loop in nested sdfg - new_before, new_guard, new_after = nsdfg.sdfg.add_loop( - before_state=None, - loop_state=nsdfg.sdfg.nodes()[0], - loop_end_state=None, - after_state=None, - loop_var=itervar, - initialize_expr=f'{start}', - condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}', - increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}') - - # remove outer loop - before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0] - for e in nsdfg.sdfg.out_edges(new_guard): - if e.dst is new_after: - guard_after_edge = e - else: - guard_body_edge = e - - if self.expr_index <= 1: - guard = self.loop_guard - for body_inedge in sdfg.in_edges(body): - if body_inedge.src is guard: - guard_body_edge.data.assignments.update(body_inedge.data.assignments) - sdfg.remove_edge(body_inedge) - for body_outedge in sdfg.out_edges(body): - sdfg.remove_edge(body_outedge) - for guard_inedge in sdfg.in_edges(guard): - before_guard_edge.data.assignments.update(guard_inedge.data.assignments) - guard_inedge.data.assignments = {} - sdfg.add_edge(guard_inedge.src, body, guard_inedge.data) - sdfg.remove_edge(guard_inedge) - for guard_outedge in sdfg.out_edges(guard): - if guard_outedge.dst is body: - guard_body_edge.data.assignments.update(guard_outedge.data.assignments) - else: - guard_after_edge.data.assignments.update(guard_outedge.data.assignments) - guard_outedge.data.condition = CodeBlock("1") - sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data) - sdfg.remove_edge(guard_outedge) - sdfg.remove_node(guard) - else: # Rotated or self loops - raise NotImplementedError('MoveLoopIntoMap not implemented for rotated and self-loops') + inner_loop = LoopRegion(self.loop.label, + self.loop.loop_condition, + self.loop.loop_variable, + self.loop.init_statement, + self.loop.update_statement, + self.loop.inverted, + nsdfg, + self.loop.update_before_condition) + inner_loop.add_node(nested_state, is_start_block=True) + nsdfg.sdfg.remove_node(nested_state) + nsdfg.sdfg.add_node(inner_loop, is_start_block=True) + + graph.add_node(body, is_start_block=(graph.start_block is self.loop)) + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, body, ie.data) + for oe in graph.out_edges(self.loop): + graph.add_edge(body, oe.dst, oe.data) + graph.remove_node(self.loop) if itervar in nsdfg.symbol_mapping: del nsdfg.symbol_mapping[itervar] diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 622dfe5595..2bb24ab4e6 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -878,7 +878,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class RefineNestedAccess(transformation.SingleStateTransformation): """ Reduces memlet shape when a memlet is connected to a nested SDFG, but not @@ -920,7 +920,7 @@ def _candidates( in_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {} out_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {} ignore = set() - for nstate in nsdfg.sdfg.nodes(): + for nstate in nsdfg.sdfg.states(): for dnode in nstate.data_nodes(): if nsdfg.sdfg.arrays[dnode.data].transient: continue @@ -967,7 +967,7 @@ def _candidates( in_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset)))) # Check read memlets in interstate edges for candidates - for e in nsdfg.sdfg.edges(): + for e in nsdfg.sdfg.all_interstate_edges(): for m in e.data.get_read_memlets(nsdfg.sdfg.arrays): # If more than one unique element detected, remove from candidates if m.data in in_candidates: @@ -1032,7 +1032,8 @@ def _check_cand(candidates, outer_edges): # If there are any symbols here that are not defined # in "defined_symbols" - missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys())) + missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), + list(indices)) - set(nsdfg.symbol_mapping.keys())) if missing_symbols: ignore.add(cname) continue @@ -1075,13 +1076,13 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], if aname in refined: continue # Refine internal memlets - for nstate in nsdfg.nodes(): + for nstate in nsdfg.states(): for e in nstate.edges(): if e.data.data == aname: e.data.subset.offset(refine.subset, True, indices) # Refine accesses in interstate edges refiner = ASTRefiner(aname, refine.subset, nsdfg, indices) - for isedge in nsdfg.edges(): + for isedge in nsdfg.all_interstate_edges(): for k, v in isedge.data.assignments.items(): vast = ast.parse(v) refiner.visit(vast) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 8c748ef8d5..b0513f6125 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -10,6 +10,8 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Type, Union from dataclasses import dataclass +from dace.sdfg.state import ControlFlowRegion + class Modifies(Flag): """ @@ -257,6 +259,53 @@ def apply(self, state: SDFGState, pipeline_results: Dict[str, Any]) -> Optional[ raise NotImplementedError +@properties.make_properties +class ControlFlowRegionPass(Pass): + """ + A specialized Pass type that applies to each control flow region separately, buttom up. Such a pass is realized by + implementing the ``apply`` method, which accepts a single control flow region, and assumes the pass was already + applied to each control flow region nested inside of that. + + :see: Pass + """ + + CATEGORY: str = 'Helper' + + def apply_pass(self, sdfg: SDFG, + pipeline_results: Dict[str, Any]) -> Optional[Dict[ControlFlowRegion, Optional[Any]]]: + """ + Applies the pass to control flow regions of the given SDFG by calling ``apply`` on each region. + + :param sdfg: The SDFG to apply the pass to. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: A dictionary of ``{region: return value}`` for visited regions with a non-None return value, or None + if nothing was returned. + """ + result = {} + for region in sdfg.all_control_flow_regions(recursive=True): + retval = self.apply(region, pipeline_results) + if retval is not None: + result[region] = retval + + if not result: + return None + return result + + def apply(self, region: ControlFlowRegion, pipeline_results: Dict[str, Any]) -> Optional[Any]: + """ + Applies this pass on the given control flow region. + + :param state: The control flow region to apply the pass to. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: Some object if pass was applied, or None if nothing changed. + """ + raise NotImplementedError + + @properties.make_properties class ScopePass(Pass): """ diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index d9fb9a7392..4c33ece899 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests for the RefineNestedAccess transformation. """ import dace import numpy as np From bb4e9c87ffaf0c7734946664c552c4fe22c82b8a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 11 Oct 2024 10:18:46 +0200 Subject: [PATCH 033/108] Fixes --- dace/sdfg/state.py | 5 ++++- dace/transformation/interstate/loop_to_map.py | 20 ++++++++++++++----- .../passes/dead_state_elimination.py | 3 ++- .../simplification/control_flow_raising.py | 3 +++ tests/transformations/loop_detection_test.py | 3 ++- 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1f80e1c2f8..11fd9c980a 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -8,7 +8,7 @@ import inspect import itertools import warnings -from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, +from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union, overload) import dace @@ -3418,6 +3418,9 @@ def node_id(self, node: 'ControlFlowBlock') -> int: def nodes(self) -> List['ControlFlowBlock']: return [node for _, node in self._branches] + def number_of_nodes(self): + return len(self._branches) + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: return [] diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index fed3118481..496adc238f 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -102,10 +102,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): _, write_set = self.loop.read_and_write_sets() loop_states = set(self.loop.all_states()) all_loop_blocks = set(self.loop.all_control_flow_blocks()) - #write_set: Set[str] = set() - #for block in loop_states: - # _, wset = block.read_and_write_sets() - # write_set |= wset # Collect symbol reads and writes from inter-state assignments in_order_loop_blocks = list(cfg_analysis.blockorder_topological_sort(self.loop, recursive=True, @@ -200,6 +196,20 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): # reassigned. in_order_blocks = list(cfg_analysis.blockorder_topological_sort(sdfg, recursive=True, ignore_nonstate_blocks=False)) + # First check the outgoing edges of the loop itself. + reassigned_symbols: Set[str] = None + for oe in graph.out_edges(self.loop): + if symbols_that_may_be_used & oe.data.read_symbols(): + return False + # Check for symbols that are set by all outgoing edges + # TODO: Handle case of subset of out_edges + if reassigned_symbols is None: + reassigned_symbols = set(oe.data.assignments.keys()) + else: + reassigned_symbols &= oe.data.assignments.keys() + # Remove reassigned symbols + if reassigned_symbols is not None: + symbols_that_may_be_used -= reassigned_symbols loop_idx = in_order_blocks.index(self.loop) for block in in_order_blocks[loop_idx + 1:]: if block in all_loop_blocks: @@ -213,7 +223,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): return False # Check inter-state edges - reassigned_symbols: Set[str] = None + reassigned_symbols = None for e in block.parent_graph.out_edges(block): if symbols_that_may_be_used & e.data.read_symbols(): return False diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 53e9b4f466..6a7e80fabf 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -41,8 +41,9 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters """ result: Set[Union[ControlFlowBlock, InterstateEdge]] = set() removed_regions: Set[ControlFlowRegion] = set() + annotated = None for cfg in list(sdfg.all_control_flow_regions()): - if cfg in removed_regions: + if cfg in removed_regions or isinstance(cfg, ConditionalBlock): continue # Mark dead blocks and remove them diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index c6aced2a6d..30f5595594 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -28,6 +28,9 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) for region in cfgs: + if isinstance(region, ConditionalBlock): + continue + dummy_exit = region.add_state('__DACE_DUMMY') for s in region.sink_nodes(): if s is not dummy_exit: diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 323a27787a..b7c1056162 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -19,7 +19,8 @@ def tester(a: dace.float64[20]): for i in range(1, 20): a[i] = a[i - 1] + 1 - sdfg = tester.to_sdfg() + tester.use_experimental_cfg_blocks = False + sdfg = tester.to_sdfg(simplify=False) xform = CountLoops() assert sdfg.apply_transformations(xform) == 1 itvar, rng, _ = xform.loop_information() From bd76961cb0e1d8107c0d546393d7a7f68852e890 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 11 Oct 2024 10:48:33 +0200 Subject: [PATCH 034/108] Adjust SDFG nesting --- dace/transformation/helpers.py | 2 +- .../interstate/gpu_transform_sdfg.py | 55 ++++++++++--------- .../transformation/interstate/sdfg_nesting.py | 23 ++++---- 3 files changed, 40 insertions(+), 40 deletions(-) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 6ca4602079..0537cdba6f 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1070,7 +1070,7 @@ def constant_symbols(sdfg: SDFG) -> Set[str]: :param sdfg: The input SDFG. :return: A set of symbol names that remain constant throughout the SDFG. """ - interstate_symbols = {k for e in sdfg.edges() for k in e.data.assignments.keys()} + interstate_symbols = {k for e in sdfg.all_interstate_edges() for k in e.data.assignments.keys()} return set(sdfg.symbols) - interstate_symbols diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 844651b071..2753844fc1 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -1,15 +1,16 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains inter-state transformations of an SDFG to run on the GPU. """ -from dace import data, memlet, dtypes, registry, sdfg as sd, symbolic, subsets as sbs, propagate_memlets_sdfg +from dace import data, memlet, dtypes, sdfg as sd, subsets as sbs, propagate_memlets_sdfg from dace.sdfg import nodes, scope from dace.sdfg import utils as sdutil +from dace.sdfg.state import SDFGState from dace.transformation import transformation, helpers as xfh from dace.properties import Property, make_properties from collections import defaultdict from copy import deepcopy as dc from sympy import floor -from typing import Dict +from typing import Dict, List, Set, Tuple gpu_storage = [dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared, dtypes.StorageType.CPU_Pinned] @@ -83,7 +84,7 @@ def _recursive_in_check(node, state, gpu_scalars): @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class GPUTransformSDFG(transformation.MultiStateTransformation): """ Implements the GPUTransformSDFG transformation. @@ -144,7 +145,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)): return False - for state in sdfg.nodes(): + for state in sdfg.states(): schildren = state.scope_children() for node in schildren[None]: # If two top-level tasklets are connected with a code->code @@ -160,14 +161,14 @@ def apply(self, _, sdfg: sd.SDFG): # Step 0: SDFG metadata # Find all input and output data descriptors - input_nodes = [] - output_nodes = [] + input_nodes: List[Tuple[str, data.Data]] = [] + output_nodes: List[Tuple[str, data.Data]] = [] global_code_nodes: Dict[sd.SDFGState, nodes.Tasklet] = defaultdict(list) # Propagate memlets to ensure that we can find the true array subsets that are written. propagate_memlets_sdfg(sdfg) - for state in sdfg.nodes(): + for state in sdfg.states(): sdict = state.scope_dict() for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient == False): @@ -190,8 +191,8 @@ def apply(self, _, sdfg: sd.SDFG): if (e.data.data not in input_nodes and sdfg.arrays[e.data.data].transient == False): input_nodes.append((e.data.data, sdfg.arrays[e.data.data])) - start_state = sdfg.start_state - end_states = sdfg.sink_nodes() + start_block = sdfg.start_block + end_blocks = sdfg.sink_nodes() ####################################################### # Step 1: Create cloned GPU arrays and replace originals @@ -230,7 +231,7 @@ def apply(self, _, sdfg: sd.SDFG): found_full_write = False full_subset = sbs.Range.from_array(onode) try: - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.data == onodename): for e in state.in_edges(node): @@ -251,20 +252,20 @@ def apply(self, _, sdfg: sd.SDFG): if not found_full_write: input_nodes.append((onodename, onode)) - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): memlets = edge.data.get_read_memlets(sdfg.arrays) for mem in memlets: if sdfg.arrays[mem.data].storage == dtypes.StorageType.GPU_Global: data_already_on_gpu[mem.data] = None # Replace nodes - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.data in cloned_arrays): node.data = cloned_arrays[node.data] # Replace memlets - for state in sdfg.nodes(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.data in cloned_arrays: edge.data.data = cloned_arrays[edge.data.data] @@ -274,7 +275,7 @@ def apply(self, _, sdfg: sd.SDFG): excluded_copyin = self.exclude_copyin.split(',') copyin_state = sdfg.add_state(sdfg.label + '_copyin') - sdfg.add_edge(copyin_state, start_state, sd.InterstateEdge()) + sdfg.add_edge(copyin_state, start_block, sd.InterstateEdge()) for nname, desc in dtypes.deduplicate(input_nodes): if nname in excluded_copyin or nname not in cloned_arrays: @@ -290,7 +291,7 @@ def apply(self, _, sdfg: sd.SDFG): excluded_copyout = self.exclude_copyout.split(',') copyout_state = sdfg.add_state(sdfg.label + '_copyout') - for state in end_states: + for state in end_blocks: sdfg.add_edge(state, copyout_state, sd.InterstateEdge()) for nname, desc in dtypes.deduplicate(output_nodes): @@ -306,8 +307,8 @@ def apply(self, _, sdfg: sd.SDFG): ####################################################### # Step 4: Change all top-level maps and library nodes to GPU schedule - gpu_nodes = set() - for state in sdfg.nodes(): + gpu_nodes: Set[Tuple[SDFGState, nodes.Node]] = set() + for state in sdfg.states(): sdict = state.scope_dict() for node in state.nodes(): if sdict[node] is None: @@ -347,7 +348,7 @@ def apply(self, _, sdfg: sd.SDFG): # inside a GPU kernel. gpu_scalars = {} - nsdfgs = [] + nsdfgs: List[Tuple[nodes.NestedSDFG, SDFGState]] = [] changed = True # Iterates over Tasklets that not inside a GPU kernel. Such Tasklets must be moved inside a GPU kernel only # if they write to GPU memory. The check takes into account the fact that GPU kernels can read host-based @@ -406,7 +407,7 @@ def apply(self, _, sdfg: sd.SDFG): const_syms = xfh.constant_symbols(sdfg) - for state in sdfg.nodes(): + for state in sdfg.states(): sdict = state.scope_dict() for node in state.nodes(): if isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient: @@ -472,22 +473,22 @@ def apply(self, _, sdfg: sd.SDFG): cloned_data = set(cloned_arrays.keys()).union(gpu_scalars.keys()).union(data_already_on_gpu.keys()) - for state in list(sdfg.nodes()): + for state in list(sdfg.states()): arrays_used = set() - for e in sdfg.out_edges(state): + for e in state.parent_graph.out_edges(state): # Used arrays = intersection between symbols and cloned data arrays_used.update(set(e.data.free_symbols) & cloned_data) # Create a state and copy out used arrays if len(arrays_used) > 0: - co_state = sdfg.add_state(state.label + '_icopyout') + co_state = state.parent_graph.add_state(state.label + '_icopyout') # Reconnect outgoing edges to after interim copyout state - for e in sdfg.out_edges(state): - sdutil.change_edge_src(sdfg, state, co_state) + for e in state.parent_graph.out_edges(state): + sdutil.change_edge_src(state.parent_graph, state, co_state) # Add unconditional edge to interim state - sdfg.add_edge(state, co_state, sd.InterstateEdge()) + state.parent_graph.add_edge(state, co_state, sd.InterstateEdge()) # Add copy-out nodes for nname in arrays_used: @@ -526,7 +527,7 @@ def apply(self, _, sdfg: sd.SDFG): co_state.add_node(dst_array) co_state.add_nedge(src_array, dst_array, memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg))) - for e in sdfg.out_edges(co_state): + for e in state.parent_graph.out_edges(co_state): e.data.replace(devicename, hostname, False) # Step 9: Simplify diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 2bb24ab4e6..ff096a3198 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ SDFG nesting transformation. """ import ast @@ -22,7 +22,7 @@ @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class InlineSDFG(transformation.SingleStateTransformation): """ Inlines a single-state nested SDFG into a top-level SDFG. @@ -99,7 +99,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): nested_sdfg = self.nested_sdfg if nested_sdfg.no_inline: return False - if len(nested_sdfg.sdfg.nodes()) != 1: + if len(nested_sdfg.sdfg.nodes()) != 1 or not isinstance(nested_sdfg.sdfg.nodes()[0], SDFGState): return False # Ensure every connector has one incoming/outgoing edge and that it @@ -154,7 +154,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): out_data[dst.data] = e.src_conn rem_inpconns = dc(in_connectors) rem_outconns = dc(out_connectors) - nstate = nested_sdfg.sdfg.node(0) + nstate: SDFGState = nested_sdfg.sdfg.nodes()[0] for node in nstate.nodes(): if isinstance(node, nodes.AccessNode): if node.data in rem_inpconns: @@ -314,7 +314,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict) # Access nodes that need to be reshaped - reshapes: Set(str) = set() + reshapes: Set[str] = set() for aname, array in nsdfg.arrays.items(): if array.transient: continue @@ -734,11 +734,10 @@ def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str, str], new @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class InlineTransients(transformation.SingleStateTransformation): """ - Inlines all transient arrays that are not used anywhere else into a - nested SDFG. + Inlines all transient arrays that are not used anywhere else into a nested SDFG. """ nsdfg = transformation.PatternNode(nodes.NestedSDFG) @@ -781,7 +780,7 @@ def _candidates(sdfg: SDFG, graph: SDFGState, nsdfg: nodes.NestedSDFG) -> Dict[s return candidates # Check for uses in other states - for state in sdfg.nodes(): + for state in sdfg.states(): if state is graph: continue for node in state.data_nodes(): @@ -1103,7 +1102,7 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class NestSDFG(transformation.MultiStateTransformation): """ Implements SDFG Nesting, taking an SDFG as an input and creating a nested SDFG node from it. """ @@ -1133,7 +1132,7 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: outputs = {} transients = {} - for state in nested_sdfg.nodes(): + for state in nested_sdfg.states(): # Input and output nodes are added as input and output nodes of the nested SDFG for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and not node.desc(nested_sdfg).transient): @@ -1254,7 +1253,7 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: nested_sdfg.arrays[newarrname].transient = False # Update memlets - for state in nested_sdfg.nodes(): + for state in nested_sdfg.states(): for _, edge in enumerate(state.edges()): _, _, _, _, mem = edge src = state.memlet_path(edge)[0].src From 3ab4bf3e379ea96fbb7635904cc8b320215fc558 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 11 Oct 2024 11:27:47 +0200 Subject: [PATCH 035/108] Updated a test for the `PruneConnectors` transformation. Before in this case the transformation was not able to be applied. The reason was because of the behaviour of the `SDFGState.read_and_write_sets()` function. However, now with the fix of the [PR#1678](https://github.com/spcl/dace/pull/1678) the transformation became applicable. --- tests/transformations/prune_connectors_test.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 4026ec3e1c..8ec0d3615a 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -153,7 +153,6 @@ def _make_read_write_sdfg( Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B` will either be associated to `inner_A` (`True`) or `inner_B` (`False`). - This choice has consequences on if the transformation can apply or not. Notes: This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643), @@ -421,7 +420,6 @@ def test_prune_connectors_with_dependencies(): def test_read_write_1(): - # Because the memlet is conforming, we can apply the transformation. sdfg, nsdfg = _make_read_write_sdfg(True) assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) @@ -429,10 +427,13 @@ def test_read_write_1(): def test_read_write_2(): - # Because the memlet is not conforming, we can not apply the transformation. + # In previous versions of DaCe the transformation could not be applied in the + # case of a not conforming Memlet. + # See [PR#1678](https://github.com/spcl/dace/pull/1678) sdfg, nsdfg = _make_read_write_sdfg(False) - assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) + assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) + sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True) if __name__ == "__main__": From 26e4ff0f6a226b14207fd780bdd8663db269f693 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 11 Oct 2024 12:39:15 +0200 Subject: [PATCH 036/108] Adapt more passes and add conditional pruning pass --- dace/transformation/helpers.py | 6 +- .../interstate/fpga_transform_sdfg.py | 15 +-- .../interstate/fpga_transform_state.py | 24 ++-- .../interstate/move_assignment_outside_if.py | 60 +++++----- .../simplification/control_flow_raising.py | 2 +- .../prune_empty_conditional_branches.py | 62 +++++++++++ dace/transformation/passes/simplify.py | 2 + .../prune_empty_conditional_branches_test.py | 105 ++++++++++++++++++ tests/transformations/gpu_transform_test.py | 2 +- .../move_assignment_outside_if_test.py | 67 ++++++----- 10 files changed, 258 insertions(+), 87 deletions(-) create mode 100644 dace/transformation/passes/simplification/prune_empty_conditional_branches.py create mode 100644 tests/passes/simplification/prune_empty_conditional_branches_test.py diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 0537cdba6f..da45c0bf04 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1419,8 +1419,8 @@ def can_run_state_on_fpga(state: SDFGState): return False # Streams have strict conditions due to code generator limitations - if (isinstance(node, nodes.AccessNode) and isinstance(graph.parent.arrays[node.data], data.Stream)): - nodedesc = graph.parent.arrays[node.data] + if (isinstance(node, nodes.AccessNode) and isinstance(graph.sdfg.arrays[node.data], data.Stream)): + nodedesc = graph.sdfg.arrays[node.data] sdict = graph.scope_dict() if nodedesc.storage in [ dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned, dtypes.StorageType.CPU_ThreadLocal @@ -1432,7 +1432,7 @@ def can_run_state_on_fpga(state: SDFGState): return False # Arrays of streams cannot have symbolic size on FPGA - if symbolic.issymbolic(nodedesc.total_size, graph.parent.constants): + if symbolic.issymbolic(nodedesc.total_size, graph.sdfg.constants): return False # Streams cannot be unbounded on FPGA diff --git a/dace/transformation/interstate/fpga_transform_sdfg.py b/dace/transformation/interstate/fpga_transform_sdfg.py index ac4672d892..5c2acf1d64 100644 --- a/dace/transformation/interstate/fpga_transform_sdfg.py +++ b/dace/transformation/interstate/fpga_transform_sdfg.py @@ -1,14 +1,15 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains inter-state transformations of an SDFG to run on an FPGA. """ import networkx as nx from dace import properties +from dace.sdfg.sdfg import SDFG from dace.transformation import transformation @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class FPGATransformSDFG(transformation.MultiStateTransformation): """ Implements the FPGATransformSDFG transformation, which takes an entire SDFG and transforms it into an FPGA-capable SDFG. """ @@ -28,20 +29,20 @@ def expressions(cls): # Match anything return [nx.DiGraph()] - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): # Avoid import loops from dace.transformation.interstate import FPGATransformState # Condition match depends on matching FPGATransformState for each state - for state_id, state in enumerate(sdfg.nodes()): + for state in sdfg.states(): fps = FPGATransformState() - fps.setup_match(sdfg, graph.cfg_id, -1, {FPGATransformState.state: state_id}, 0) - if not fps.can_be_applied(sdfg, expr_index, sdfg): + fps.setup_match(sdfg, state.parent_graph.cfg_id, -1, {FPGATransformState.state: state.block_id}, 0) + if not fps.can_be_applied(state.parent_graph, expr_index, sdfg): return False return True - def apply(self, _, sdfg): + def apply(self, _, sdfg: SDFG): # Avoid import loops from dace.transformation.interstate import NestSDFG from dace.transformation.interstate import FPGATransformState diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index 60a2a33001..47ba478341 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -1,15 +1,17 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains inter-state transformations of an SDFG to run on an FPGA. """ import copy import dace -from dace import data, memlet, dtypes, registry, sdfg as sd, subsets +from dace import memlet, dtypes, sdfg as sd, subsets from dace.sdfg import nodes from dace.sdfg import utils as sdutil +from dace.sdfg.sdfg import SDFG +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.transformation import transformation, helpers as xfh -def fpga_update(sdfg, state, depth): +def fpga_update(sdfg: SDFG, state: SDFGState, depth: int): scope_dict = state.scope_dict() for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.desc(sdfg).storage == dtypes.StorageType.Default): @@ -29,7 +31,7 @@ def fpga_update(sdfg, state, depth): fpga_update(node.sdfg, s, depth + 1) -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class FPGATransformState(transformation.MultiStateTransformation): """ Implements the FPGATransformState transformation. """ @@ -76,7 +78,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): state = self.state # Find source/sink (data) nodes that are relevant outside this FPGA @@ -158,9 +160,9 @@ def apply(self, _, sdfg): sdutil.change_edge_src(state, node, fpga_node) state.remove_node(node) - sdfg.add_node(pre_state) - sdutil.change_edge_dest(sdfg, state, pre_state) - sdfg.add_edge(pre_state, state, sd.InterstateEdge()) + graph.add_node(pre_state) + sdutil.change_edge_dest(graph, state, pre_state) + graph.add_edge(pre_state, state, sd.InterstateEdge()) if output_nodes: @@ -200,9 +202,9 @@ def apply(self, _, sdfg): sdutil.change_edge_dest(state, node, fpga_node) state.remove_node(node) - sdfg.add_node(post_state) - sdutil.change_edge_src(sdfg, state, post_state) - sdfg.add_edge(state, post_state, sd.InterstateEdge()) + graph.add_node(post_state) + sdutil.change_edge_src(graph, state, post_state) + graph.add_edge(state, post_state, sd.InterstateEdge()) # propagate memlet info from a nested sdfg for src, src_conn, dst, dst_conn, mem in state.edges(): diff --git a/dace/transformation/interstate/move_assignment_outside_if.py b/dace/transformation/interstate/move_assignment_outside_if.py index 3b101818ca..6522c67eb8 100644 --- a/dace/transformation/interstate/move_assignment_outside_if.py +++ b/dace/transformation/interstate/move_assignment_outside_if.py @@ -1,58 +1,51 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Transformation to move assignments outside if statements to potentially avoid warp divergence. Speedup gained is questionable. """ import ast +from typing import Dict, List, Tuple import sympy as sp from dace import sdfg as sd -from dace.sdfg import graph as gr -from dace.sdfg.nodes import Tasklet, AccessNode +from dace.sdfg import graph as gr, utils as sdutil, nodes as nd +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion +from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class MoveAssignmentOutsideIf(transformation.MultiStateTransformation): - if_guard = transformation.PatternNode(sd.SDFGState) - if_stmt = transformation.PatternNode(sd.SDFGState) - else_stmt = transformation.PatternNode(sd.SDFGState) + conditional = transformation.PatternNode(ConditionalBlock) @classmethod def expressions(cls): - sdfg = gr.OrderedDiGraph() - sdfg.add_nodes_from([cls.if_guard, cls.if_stmt, cls.else_stmt]) - sdfg.add_edge(cls.if_guard, cls.if_stmt, sd.InterstateEdge()) - sdfg.add_edge(cls.if_guard, cls.else_stmt, sd.InterstateEdge()) - return [sdfg] + return [sdutil.node_path_graph(cls.conditional)] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # The if-guard can only have two outgoing edges: to the if and to the else part - guard_outedges = graph.out_edges(self.if_guard) - if len(guard_outedges) != 2: + # The conditional can only have two branches, with conditions either being negations of one another, or the + # second branch being an 'else' branch. + if len(self.conditional.branches) != 2: return False - - # Outgoing edges must be a negation of each other - if guard_outedges[0].data.condition_sympy() != (sp.Not(guard_outedges[1].data.condition_sympy())): - return False - - # The if guard should either have zero or one incoming edge - if len(sdfg.in_edges(self.if_guard)) > 1: + fcond = self.conditional.branches[0][0] + scond = self.conditional.branches[1][0] + if (fcond is None or (scond is not None and + (pystr_to_symbolic(fcond.as_string)) != sp.Not(pystr_to_symbolic(scond.as_string)))): return False # set of the variables which get a const value assigned assigned_const = set() # Dict which collects all AccessNodes for each variable together with its state - access_nodes = {} + access_nodes: Dict[str, List[Tuple[nd.AccessNode, sd.SDFGState]]] = {} # set of the variables which are only written to self.write_only_values = set() # Dictionary which stores additional information for the variables which are written only self.assign_context = {} - for state in [self.if_stmt, self.else_stmt]: + for state in self.conditional.all_states(): for node in state.nodes(): - if isinstance(node, Tasklet): + if isinstance(node, nd.Tasklet): # If node is a tasklet, check if assigns a constant value assigns_const = True for code_stmt in node.code.code: @@ -60,10 +53,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): assigns_const = False if assigns_const: for edge in state.out_edges(node): - if isinstance(edge.dst, AccessNode): + if isinstance(edge.dst, nd.AccessNode): assigned_const.add(edge.dst.data) - self.assign_context[edge.dst.data] = {"state": state, "tasklet": node} - elif isinstance(node, AccessNode): + self.assign_context[edge.dst.data] = {'state': state, 'tasklet': node} + elif isinstance(node, nd.AccessNode): if node.data not in access_nodes: access_nodes[node.data] = [] access_nodes[node.data].append((node, state)) @@ -92,14 +85,14 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False return True - def apply(self, _, sdfg: sd.SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # create a new state before the guard state where the zero assignment happens - new_assign_state = sdfg.add_state_before(self.if_guard, label="const_assignment_state") + new_assign_state = graph.add_state_before(self.conditional, label='const_assignment_state') # Move all the Tasklets together with the AccessNode for value in self.write_only_values: - state = self.assign_context[value]["state"] - tasklet = self.assign_context[value]["tasklet"] + state: sd.SDFGState = self.assign_context[value]['state'] + tasklet: nd.Tasklet = self.assign_context[value]['tasklet'] new_assign_state.add_node(tasklet) for edge in state.out_edges(tasklet): state.remove_edge(edge) @@ -110,5 +103,4 @@ def apply(self, _, sdfg: sd.SDFG): state.remove_node(tasklet) # Remove the state if it was emptied if state.is_empty(): - sdfg.remove_node(state) - return sdfg + graph.remove_node(state) diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 30f5595594..ce913b8f66 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -5,7 +5,7 @@ from dace import properties from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion from dace.sdfg.utils import dfs_conditional from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.interstate.loop_lifting import LoopLifting diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py new file mode 100644 index 0000000000..111944614c --- /dev/null +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -0,0 +1,62 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Optional +from dace import properties +from dace.frontend.python import astutils +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, SDFGState +from dace.transformation import pass_pipeline as ppl, transformation + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass): + + CATEGORY: str = 'Simplification' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.CFG + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def apply(self, region: ControlFlowRegion, _) -> Optional[int]: + if not isinstance(region, ConditionalBlock): + return None + removed_branches = 0 + all_branches = region.branches + has_else = all_branches[-1][0] is None + new_else_cond = None + for cond, branch in all_branches: + branch_nodes = branch.nodes() + if (len(branch_nodes) == 0 or (len(branch_nodes) == 1 and isinstance(branch_nodes[0], SDFGState) and + len(branch_nodes[0].nodes()) == 0)): + # Found a branch we can eliminate. + if has_else and branch is not all_branches[-1][1]: + # If this conditional has an else branch and that is not the branch being eliminated, we need to + # change that else branch to a conditional else-if branch that negates the current branch's + # condition. + negated_condition = astutils.negate_expr(cond.code[0]) + if new_else_cond is None: + new_else_cond = properties.CodeBlock([negated_condition]) + else: + combined_cond = astutils.and_expr(negated_condition, new_else_cond.code[0]) + new_else_cond = properties.CodeBlock([combined_cond]) + region.remove_branch(branch) + else: + # Simple case, eliminate the branch. + region.remove_branch(branch) + removed_branches += 1 + # If the else branch remains, make sure it now has the new negate-all condition. + if new_else_cond is not None and region.branches[-1][0] is None: + region._branches[-1] = (new_else_cond, region._branches[-1][1]) + + if len(region.branches) == 0: + # The conditional has become entirely empty, remove it. + replacement_node_before = region.parent_graph.add_state_before(region) + replacement_node_after = region.parent_graph.add_state_before(region) + region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge()) + region.parent_graph.remove_node(region) + + return removed_branches if removed_branches > 0 else None + diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index bd28f8d377..c6177966f4 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -17,6 +17,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols from dace.transformation.passes.reference_reduction import ReferenceToView from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches SIMPLIFY_PASSES = [ InlineSDFGs, @@ -27,6 +28,7 @@ ConstantPropagation, DeadDataflowElimination, DeadStateElimination, + PruneEmptyConditionalBranches, RemoveUnusedSymbols, ReferenceToView, ArrayElimination, diff --git a/tests/passes/simplification/prune_empty_conditional_branches_test.py b/tests/passes/simplification/prune_empty_conditional_branches_test.py new file mode 100644 index 0000000000..65463ad3a7 --- /dev/null +++ b/tests/passes/simplification/prune_empty_conditional_branches_test.py @@ -0,0 +1,105 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + + +import numpy as np +import dace +from dace.sdfg.state import ConditionalBlock +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches + + +def test_prune_empty_else(): + N = dace.symbol('N') + + @dace.program + def prune_empty_else(A: dace.int32[N]): + A[:] = 0 + if N == 32: + for i in range(N): + A[i] = 1 + else: + A[:] = 0 + + sdfg = prune_empty_else.to_sdfg(simplify=False) + + conditional: ConditionalBlock = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + + assert len(conditional.branches) == 2 + + conditional._branches[-1][0] = None + else_branch = conditional._branches[-1][1] + else_branch.remove_nodes_from(else_branch.nodes()) + else_branch.add_state('empty') + + res = PruneEmptyConditionalBranches().apply_pass(sdfg, {}) + + assert res[conditional] == 1 + assert len(conditional.branches) == 1 + + N1 = 32 + N2 = 31 + A1 = np.zeros((N1,), dtype=np.int32) + A2 = np.zeros((N2,), dtype=np.int32) + verif1 = np.full((N1,), 1, dtype=np.int32) + verif2 = np.zeros((N2,), dtype=np.int32) + + sdfg(A1, N=N1) + sdfg(A2, N=N2) + + assert np.allclose(A1, verif1) + assert np.allclose(A2, verif2) + + +def test_prune_empty_if_with_else(): + N = dace.symbol('N') + + @dace.program + def prune_empty_if_with_else(A: dace.int32[N]): + A[:] = 0 + if N == 32: + for i in range(N): + A[i] = 2 + else: + A[:] = 1 + + sdfg = prune_empty_if_with_else.to_sdfg(simplify=False) + + conditional: ConditionalBlock = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + + assert len(conditional.branches) == 2 + + conditional._branches[-1][0] = None + if_branch = conditional._branches[0][1] + if_branch.remove_nodes_from(if_branch.nodes()) + if_branch.add_state('empty') + + res = PruneEmptyConditionalBranches().apply_pass(sdfg, {}) + + assert res[conditional] == 1 + assert len(conditional.branches) == 1 + assert conditional.branches[0][0] is not None + + N1 = 32 + N2 = 31 + A1 = np.zeros((N1,), dtype=np.int32) + A2 = np.zeros((N2,), dtype=np.int32) + verif1 = np.zeros((N1,), dtype=np.int32) + verif2 = np.full((N2,), 1, dtype=np.int32) + + sdfg(A1, N=N1) + sdfg(A2, N=N2) + + assert np.allclose(A1, verif1) + assert np.allclose(A2, verif2) + + +if __name__ == '__main__': + test_prune_empty_else() + test_prune_empty_if_with_else() diff --git a/tests/transformations/gpu_transform_test.py b/tests/transformations/gpu_transform_test.py index f6d299e630..2099077d81 100644 --- a/tests/transformations/gpu_transform_test.py +++ b/tests/transformations/gpu_transform_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Unit tests for the GPU to-device transformation. """ import dace diff --git a/tests/transformations/move_assignment_outside_if_test.py b/tests/transformations/move_assignment_outside_if_test.py index 270fd8f842..13725738e7 100644 --- a/tests/transformations/move_assignment_outside_if_test.py +++ b/tests/transformations/move_assignment_outside_if_test.py @@ -1,5 +1,7 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + import dace +from dace.sdfg.state import ConditionalBlock from dace.transformation.interstate import MoveAssignmentOutsideIf from dace.sdfg import InterstateEdge from dace.memlet import Memlet @@ -35,18 +37,23 @@ def one_variable_simple_test(const_value: int = 0): # Create if-else condition such that either the formula state or the const state is executed sdfg.add_edge(guard, formula_state, InterstateEdge(condition='B[0] < 0.5')) sdfg.add_edge(guard, const_state, InterstateEdge(condition='B[0] >= 0.5')) + sdfg.simplify() sdfg.validate() # Assure transformation is applied assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + sdfg.simplify() # SDFG now starts with a state containing the const_tasklet - assert const_tasklet in sdfg.start_state.nodes() - # The formula state has only one in_edge with the condition - assert len(sdfg.in_edges(formula_state)) == 1 - assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(B[0] < 0.5)' - # All state have at most one out_edge -> there is no if-else branching anymore - for state in sdfg.states(): - assert len(sdfg.out_edges(state)) <= 1 + assert const_tasklet in sdfg.start_block.nodes() + # There should now only be one conditional branch remaining in the entire SDFG. + conditional = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + assert conditional is not None + assert len(conditional.branches) == 1 + assert conditional.branches[0][0].as_string == '(B[0] < 0.5)' def multiple_variable_test(): @@ -89,21 +96,26 @@ def multiple_variable_test(): # Create if-else condition such that either the formula state or the const state is executed sdfg.add_edge(guard, formula_state, InterstateEdge(condition='D[0] < 0.5')) sdfg.add_edge(guard, const_state, InterstateEdge(condition='D[0] >= 0.5')) + sdfg.simplify() sdfg.validate() # Assure transformation is applied assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + sdfg.simplify() # There are no other tasklets in the start state beside the const assignment tasklet as there are no other const # assignments - for node in sdfg.start_state.nodes(): + for node in sdfg.start_block.nodes(): if isinstance(node, Tasklet): assert node == const_tasklet_a or node == const_tasklet_b - # The formula state has only one in_edge with the condition - assert len(sdfg.in_edges(formula_state)) == 1 - assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(D[0] < 0.5)' - # All state have at most one out_edge -> there is no if-else branching anymore - for state in sdfg.states(): - assert len(sdfg.out_edges(state)) <= 1 + # There should now only be one conditional branch remaining in the entire SDFG. + conditional = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + assert conditional is not None + assert len(conditional.branches) == 1 + assert conditional.branches[0][0].as_string == '(D[0] < 0.5)' def multiple_variable_not_all_const_test(): @@ -145,6 +157,7 @@ def multiple_variable_not_all_const_test(): # Create if-else condition such that either the formula state or the const state is executed sdfg.add_edge(guard, formula_state, InterstateEdge(condition='C[0] < 0.5')) sdfg.add_edge(guard, const_state, InterstateEdge(condition='C[0] >= 0.5')) + sdfg.simplify() sdfg.validate() # Assure transformation is applied @@ -154,24 +167,18 @@ def multiple_variable_not_all_const_test(): for node in sdfg.start_state.nodes(): if isinstance(node, Tasklet): assert node == const_tasklet_a or node == const_tasklet_b - # The formula state has only one in_edge with the condition - assert len(sdfg.in_edges(formula_state)) == 1 - assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(C[0] < 0.5)' - # Guard still has two outgoing edges as if-else pattern still exists - assert len(sdfg.out_edges(guard)) == 2 - # const state now has only const_tasklet_b left plus two access nodes - assert len(const_state.nodes()) == 3 - for node in const_state.nodes(): - if isinstance(node, Tasklet): - assert node == const_tasklet_b + # The conditional should still have two conditional branches + conditional = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + assert conditional is not None + assert len(conditional.branches) == 2 -def main(): +if __name__ == '__main__': one_variable_simple_test(0) one_variable_simple_test(2) multiple_variable_test() multiple_variable_not_all_const_test() - - -if __name__ == '__main__': - main() From b4feddfe8459dcb49e46ef416512e022624d2e60 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 11 Oct 2024 13:18:49 +0200 Subject: [PATCH 037/108] Added code to `test_more_than_a_map` to ensure that the transformation does not change the behaviour of teh output. --- .../move_loop_into_map_test.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index dca775bb7a..de610d0ca8 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -2,6 +2,7 @@ import dace from dace.transformation.interstate import MoveLoopIntoMap import unittest +import copy import numpy as np I = dace.symbol("I") @@ -170,7 +171,26 @@ def test_more_than_a_map(self): body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) + + sdfg_args_ref = { + "A": np.array(np.random.rand(3, 3), dtype=np.float64), + "B": np.array(np.random.rand(3, 3), dtype=np.float64), + "out": np.array(np.random.rand(3, 3), dtype=np.float64), + } + sdfg_args_res = copy.deepcopy(sdfg_args_ref) + + # Perform the reference execution + sdfg(**sdfg_args_ref) + + # Apply the transformation and execute the SDFG again. + count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) + sdfg(**sdfg_args_res) + + for name in sdfg_args_ref.keys(): + self.assertTrue( + np.allclose(sdfg_args_ref[name], sdfg_args_res[name]), + f"Miss match for {name}", + ) self.assertFalse(count > 0) def test_more_than_a_map_1(self): From 324fa34538751e4413a63bceaed16adc0a53b6f5 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 11 Oct 2024 14:36:42 +0200 Subject: [PATCH 038/108] More fixes --- dace/sdfg/utils.py | 42 +++++--- dace/transformation/interstate/__init__.py | 1 + .../transformation/interstate/block_fusion.py | 100 ++++++++++++++++++ .../interstate/multistate_inline.py | 48 ++++----- dace/transformation/passes/prune_symbols.py | 42 ++++---- tests/inlining_test.py | 20 ++-- tests/multistate_init_test.py | 2 +- 7 files changed, 190 insertions(+), 65 deletions(-) create mode 100644 dace/transformation/interstate/block_fusion.py diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index ed179df0cf..a272057380 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1195,7 +1195,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> shows progress bar. :return: The total number of states fused. """ - from dace.transformation.interstate import StateFusion # Avoid import loop + from dace.transformation.interstate import StateFusion, BlockFusion # Avoid import loop + if progress is None and not config.Config.get_bool('progress'): progress = False @@ -1228,20 +1229,33 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> progress = True pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) or - not isinstance(u, SDFGState)): + if u in skip_nodes or v in skip_nodes: continue - candidate = {StateFusion.first_state: u, StateFusion.second_state: v} - sf = StateFusion() - sf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) - if sf.can_be_applied(cfg, 0, sd, permissive=permissive): - sf.apply(cfg, sd) - applied += 1 - counter += 1 - if progress: - pbar.update(1) - skip_nodes.add(u) - skip_nodes.add(v) + + if isinstance(u, SDFGState) and isinstance(v, SDFGState): + candidate = {StateFusion.first_state: u, StateFusion.second_state: v} + sf = StateFusion() + sf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): + sf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) + else: + candidate = {BlockFusion.first_block: u, BlockFusion.second_block: v} + bf = BlockFusion() + bf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) + if bf.can_be_applied(cfg, 0, sd, permissive=permissive): + bf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) if applied == 0: break if progress: diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index b8bcc716e6..a53152e09c 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ This module initializes the inter-state transformations package.""" +from .block_fusion import BlockFusion from .state_fusion import StateFusion from .state_fusion_with_happens_before import StateFusionExtended from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination, diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py new file mode 100644 index 0000000000..a890551ba1 --- /dev/null +++ b/dace/transformation/interstate/block_fusion.py @@ -0,0 +1,100 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from dace.sdfg import utils as sdutil +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, SDFGState +from dace.transformation import transformation + + +@transformation.experimental_cfg_block_compatible +class BlockFusion(transformation.MultiStateTransformation): + """ Implements the state-fusion transformation. + + State-fusion takes two states that are connected through a single edge, + and fuses them into one state. If permissive, also applies if potential memory + access hazards are created. + """ + + first_block = transformation.PatternNode(ControlFlowBlock) + second_block = transformation.PatternNode(ControlFlowBlock) + + @staticmethod + def annotates_memlets(): + return False + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.first_block, cls.second_block)] + + def _is_noop(self, block: ControlFlowBlock) -> bool: + if isinstance(block, SDFGState): + return block.is_empty() + elif type(block) == ControlFlowBlock: + return True + return False + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # First block must have only one unconditional output edge (with dst the second block). + out_edges = graph.out_edges(self.first_block) + if len(out_edges) != 1 or out_edges[0].dst is not self.second_block or not out_edges[0].data.is_unconditional(): + return False + # Inversely, the second block may only have one input edge, with src being the first block. + in_edges_second = graph.in_edges(self.second_block) + if len(in_edges_second) != 1 or in_edges_second[0].src is not self.first_block: + return False + + # Ensure that either that both blocks are fusable blocks, meaning that at least one of the two blocks must be + # a 'no-op' block. That can be an empty SDFGState or a general control flow block without further semantics + # (no loop, conditional, break, continue, control flow region, etc.). + if not self._is_noop(self.first_block) and not self._is_noop(self.second_block): + return False + + # The interstate edge may have assignments if there are input edges to the first block that can absorb them. + in_edges = graph.in_edges(self.first_block) + if out_edges[0].data.assignments: + if not in_edges: + return False + # Fail if symbol is set before the block to fuse + new_assignments = set(out_edges[0].data.assignments.keys()) + if any((new_assignments & set(e.data.assignments.keys())) for e in in_edges): + return False + # Fail if symbol is used in the dataflow of that block + if len(new_assignments & self.first_block.free_symbols) > 0: + return False + # Fail if symbols assigned on the first edge are free symbols on the second edge + symbols_used = set(out_edges[0].data.free_symbols) + for e in in_edges: + if e.data.assignments.keys() & symbols_used: + return False + # Also fail in the inverse; symbols assigned on the second edge are free symbols on the first edge + if new_assignments & set(e.data.free_symbols): + return False + + # There can be no block that has output edges pointing to both the first and the second block. Such a case will + # produce a multi-graph. + for src, _, _ in in_edges: + for _, dst, _ in graph.out_edges(src): + if dst == self.second_block: + return False + return True + + def apply(self, graph: ControlFlowRegion, sdfg): + connecting_edge = graph.edges_between(self.first_block, self.second_block)[0] + assignments_to_absorb = connecting_edge.data.assignments + graph.remove_edge(connecting_edge) + for ie in graph.out_edges(self.first_block): + if assignments_to_absorb: + ie.data.assignments.update(assignments_to_absorb) + + if self._is_noop(self.first_block): + # We remove the first block and let the second one remain. + first_is_start = graph.start_block is self.first_block + for ie in graph.in_edges(self.first_block): + graph.add_edge(ie.src, self.second_block, ie.data) + graph.remove_node(self.first_block) + if first_is_start: + graph.start_block = self.second_block.block_id + else: + # We remove the second block and let the first one remain. + for oe in graph.out_edges(self.second_block): + graph.add_edge(self.first_block, oe.dst, oe.data) + graph.remove_node(self.second_block) diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 42dccd8616..c39c5868d4 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Inline multi-state SDFGs. """ from copy import deepcopy as dc @@ -18,7 +18,7 @@ @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class InlineMultistateSDFG(transformation.SingleStateTransformation): """ Inlines a multi-state nested SDFG into a top-level SDFG. This only happens @@ -74,7 +74,7 @@ def replfunc(mapping): return all(istr == ostr for istr, ostr in zip(istrides, ostrides)) - def can_be_applied(self, state: SDFGState, expr_index, sdfg, permissive=False): + def can_be_applied(self, state: SDFGState, expr_index, sdfg: SDFG, permissive=False): nested_sdfg = self.nested_sdfg if nested_sdfg.no_inline: return False @@ -146,14 +146,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sdfg.append_exit_code(code.code, loc) # Environments - for nstate in nsdfg.nodes(): + for nstate in nsdfg.states(): for node in nstate.nodes(): if isinstance(node, nodes.CodeNode): node.environments |= nsdfg_node.environments # Symbols outer_symbols = {str(k): v for k, v in sdfg.symbols.items()} - for ise in sdfg.edges(): + for ise in sdfg.all_interstate_edges(): outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) # Isolate nsdfg in a separate state @@ -188,11 +188,11 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Collect and modify interstate edges as necessary outer_assignments = set() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): outer_assignments |= e.data.assignments.keys() inner_assignments = set() - for e in nsdfg.edges(): + for e in nsdfg.all_interstate_edges(): inner_assignments |= e.data.assignments.keys() allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys()) @@ -234,7 +234,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # All transients become transients of the parent (if data already # exists, find new name) - for nstate in nsdfg.nodes(): + for nstate in nsdfg.states(): for node in nstate.nodes(): if isinstance(node, nodes.AccessNode): datadesc = nsdfg.arrays[node.data] @@ -326,8 +326,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e.dst_conn, e.data) # Make unique names for states - statenames = set(s.label for s in sdfg.nodes()) - for nstate in nsdfg.nodes(): + statenames = set(s.label for s in sdfg.states()) + for nstate in nsdfg.states(): if nstate.label in statenames: newname = data.find_new_name(nstate.label, statenames) statenames.add(newname) @@ -336,11 +336,11 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): ####################################################### # Add nested SDFG states into top-level SDFG - outer_start_state = sdfg.start_state + outer_start_state = outer_state.parent_graph.start_block - sdfg.add_nodes_from(nsdfg.nodes()) + outer_state.parent_graph.add_nodes_from(nsdfg.nodes()) for ise in nsdfg.edges(): - sdfg.add_edge(ise.src, ise.dst, ise.data) + outer_state.parent_graph.add_edge(ise.src, ise.dst, ise.data) ####################################################### # Reconnect inlined SDFG @@ -349,19 +349,19 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sinks = nsdfg.sink_nodes() # Reconnect state machine - for e in sdfg.in_edges(nsdfg_state): - sdfg.add_edge(e.src, source, e.data) - for e in sdfg.out_edges(nsdfg_state): + for e in outer_state.parent_graph.in_edges(nsdfg_state): + outer_state.parent_graph.add_edge(e.src, source, e.data) + for e in outer_state.parent_graph.out_edges(nsdfg_state): for sink in sinks: - sdfg.add_edge(sink, e.dst, dc(e.data)) + outer_state.parent_graph.add_edge(sink, e.dst, dc(e.data)) # Redirect sink incoming edges with a `False` condition to e.dst (return statements) - for e2 in sdfg.in_edges(sink): + for e2 in outer_state.parent_graph.in_edges(sink): if e2.data.condition_sympy() == False: - sdfg.add_edge(e2.src, e.dst, InterstateEdge()) + outer_state.parent_graph.add_edge(e2.src, e.dst, InterstateEdge()) # Modify start state as necessary if outer_start_state is nsdfg_state: - sdfg.start_state = sdfg.node_id(source) + outer_state.parent_graph.start_block = outer_state.parent_graph.node_id(source) # TODO: Modify memlets by offsetting # If both source and sink nodes are inputs/outputs, reconnect once @@ -406,8 +406,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e.data, outer_edge.data) # Replace nested SDFG parents with new SDFG - for nstate in nsdfg.nodes(): - nstate.parent = sdfg + for nstate in nsdfg.states(): + nstate.sdfg = sdfg for node in nstate.nodes(): if isinstance(node, nodes.NestedSDFG): node.sdfg.parent_sdfg = sdfg @@ -415,9 +415,9 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): ####################################################### # Remove nested SDFG and state - sdfg.remove_node(nsdfg_state) + outer_state.parent_graph.remove_node(nsdfg_state) - sdfg._cfg_list = sdfg.reset_cfg_list() + sdfg.reset_cfg_list() return nsdfg.nodes() diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 3b3940f804..a8385b493f 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -6,12 +6,13 @@ from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes +from dace.sdfg.state import SDFGState from dace.transformation import pass_pipeline as ppl, transformation @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class RemoveUnusedSymbols(ppl.Pass): """ Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``) and interstate edges. @@ -64,7 +65,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]: sid = sdfg.cfg_id result = set((sid, sym) for sym in result) - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): old_symbols = self.symbols @@ -90,27 +91,28 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) - for state in sdfg.nodes(): - result |= state.free_symbols + for block in sdfg.all_control_flow_blocks(): + result |= block.free_symbols # In addition to the standard free symbols, we are conservative with other tasklet languages by # tokenizing their code. Since this is intersected with `sdfg.symbols`, keywords such as "if" are # ok to include - for node in state.nodes(): - if isinstance(node, nodes.Tasklet): - if node.code.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_global.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_init.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_exit.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - - for e in sdfg.edges(): + if isinstance(block, SDFGState): + for node in block.nodes(): + if isinstance(node, nodes.Tasklet): + if node.code.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_global.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_init.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_exit.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + + for e in sdfg.all_interstate_edges(): result |= e.data.free_symbols return result diff --git a/tests/inlining_test.py b/tests/inlining_test.py index 7c3510daed..9d802e6bec 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace.sdfg.state import FunctionCallRegion, NamedRegion from dace.transformation.interstate import InlineSDFG, StateFusion from dace.libraries import blas from dace.library import change_default @@ -134,7 +135,7 @@ def outerprog(A: dace.float64[20]): from dace.transformation.interstate import InlineMultistateSDFG sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (4, 5) + assert sdfg.number_of_nodes() in (1, 2) sdfg(A) assert np.allclose(A, expected) @@ -145,14 +146,14 @@ def test_multistate_inline_samename(): @dace.program def nested(A: dace.float64[20]): for i in range(5): - A[i] += A[i - 1] + A[i + 1] += A[i] @dace.program def outerprog(A: dace.float64[20]): for i in range(5): nested(A) - sdfg = outerprog.to_sdfg(simplify=True) + sdfg = outerprog.to_sdfg(simplify=False) A = np.random.rand(20) expected = np.copy(A) @@ -160,7 +161,8 @@ def outerprog(A: dace.float64[20]): from dace.transformation.interstate import InlineMultistateSDFG sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (7, 8) + sdfg.simplify() + assert sdfg.number_of_nodes() == 1 sdfg(A) assert np.allclose(A, expected) @@ -193,8 +195,11 @@ def outerprog(A: dace.float64[20], B: dace.float64[20]): b = 2 * a sdfg = outerprog.to_sdfg(simplify=False) + for cf in sdfg.all_control_flow_regions(): + if isinstance(cf, (FunctionCallRegion, NamedRegion)): + cf.inline() sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) - assert len(sdfg.states()) == 1 + assert len(sdfg.nodes()) == 1 A = np.random.rand(20) B = np.random.rand(20) @@ -229,9 +234,12 @@ def outerprog(A: dace.float64[10], B: dace.float64[10], C: dace.float64[10]): c = 2 * a sdfg = outerprog.to_sdfg(simplify=False) + for cf in sdfg.all_control_flow_regions(): + if isinstance(cf, (FunctionCallRegion, NamedRegion)): + cf.inline() dace.propagate_memlets_sdfg(sdfg) sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) - assert len(sdfg.states()) == 1 + assert len(sdfg.nodes()) == 1 assert len([node for node in sdfg.start_state.data_nodes()]) == 3 A = np.random.rand(10) diff --git a/tests/multistate_init_test.py b/tests/multistate_init_test.py index 3359ce3e56..8efc68d260 100644 --- a/tests/multistate_init_test.py +++ b/tests/multistate_init_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace import numpy as np From 70fa3db43020073293dca5bdf9d2bf6649c93814 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 14 Oct 2024 09:29:58 +0200 Subject: [PATCH 039/108] Added the new memlet creation syntax. --- tests/sdfg/state_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 7d74ae6bcc..33e02088a4 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -71,17 +71,17 @@ def test_read_and_write_set_filter(): state.add_nedge( A, B, - dace.Memlet("B[0] -> 0, 0"), + dace.Memlet("B[0] -> [0, 0]"), ) state.add_nedge( B, C, - dace.Memlet("C[1, 1] -> 0"), + dace.Memlet("C[1, 1] -> [0]"), ) state.add_nedge( B, C, - dace.Memlet("B[0] -> 0, 0"), + dace.Memlet("B[0] -> [0, 0]"), ) sdfg.validate() From b187a8260ba197280b7b1c3cd8075a6a0f47a810 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 14 Oct 2024 13:15:13 +0200 Subject: [PATCH 040/108] Modified some comments to make them clearer. --- dace/sdfg/state.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 855cf83f62..53f8d98491 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -766,7 +766,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, subgraph_write_set = collections.defaultdict(list) # Traverse in topological order, so data that is written before it # is read is not counted in the read set - # TODO: This only works if every data descriptor is only once in a path. + # NOTE: Each AccessNode is processed individually. Thus, if an array appears multiple + # times in a path, the individual results are combined, without further processing. for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()): if not isinstance(n, nd.AccessNode): # Read and writes can only be done through access nodes, @@ -808,7 +809,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # exists a memlet at the same access node that writes to the same region, then # this read is ignored, and not included in the final read set, but only # accounted fro in the write set. See also note below. - # TODO: Handle the case when multiple disjoint writes are needed to cover the read. + # TODO: Handle the case when multiple disjoint writes would be needed to cover the + # read. E.g. edges write `0:10` and `10:20` but the read happens at `5:15`. for out_edge in list(out_edges): for in_edge in in_edges: if in_subsets[in_edge].covers(out_subsets[out_edge]): From 9c6cb6c91cd924cb9269c727192778106fe36a17 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 14 Oct 2024 13:15:48 +0200 Subject: [PATCH 041/108] Modified the `tests/transformations/move_loop_into_map_test.py::test_more_than_a_map()` test. The test now assumes that it can be applied, as it was discussed on github this should be the correct behaviour. Furthermore, there is also a test which ensures that the same values are computed. This commit also adds an additional test (``tests/transformations/move_loop_into_map_test.py::test_more_than_a_map_4()`) which is essentially the same. However, in one of its Memlet it is a bit different and therefore the transformation can not be applied. --- .../move_loop_into_map_test.py | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index de610d0ca8..fbb05d30f5 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -148,7 +148,10 @@ def test_apply_multiple_times_1(self): self.assertTrue(np.allclose(val, ref)) def test_more_than_a_map(self): - """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """ + """ + `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. + However, there is no dependency. + """ sdfg = dace.SDFG('more_than_a_map') _, aarr = sdfg.add_array('A', (3, 3), dace.float64) _, barr = sdfg.add_array('B', (3, 3), dace.float64) @@ -168,7 +171,7 @@ def test_more_than_a_map(self): external_edges=True, input_nodes=dict(out=oread, B=bread), output_nodes=dict(tmp=twrite)) - body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) + body.add_nedge(aread, oread, dace.Memlet.from_array('A', oarr)) body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') @@ -191,7 +194,7 @@ def test_more_than_a_map(self): np.allclose(sdfg_args_ref[name], sdfg_args_res[name]), f"Miss match for {name}", ) - self.assertFalse(count > 0) + self.assertTrue(count > 0) def test_more_than_a_map_1(self): """ @@ -289,6 +292,55 @@ def test_more_than_a_map_3(self): count = sdfg.apply_transformations(MoveLoopIntoMap) self.assertFalse(count > 0) + def test_more_than_a_map_4(self): + """ + The test is very similar to `test_more_than_a_map()`. But a memlet is different + which leads to a RW dependency, which blocks the transformation. + """ + sdfg = dace.SDFG('more_than_a_map') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + body = sdfg.add_state('map_state') + aread = body.add_access('A') + oread = body.add_access('out') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(out=oread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(aread, oread, dace.Memlet('A[Mod(_, 3), 0:3] -> [Mod(_ + 1, 3), 0:3]', aarr)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') + + sdfg_args_ref = { + "A": np.array(np.random.rand(3, 3), dtype=np.float64), + "B": np.array(np.random.rand(3, 3), dtype=np.float64), + "out": np.array(np.random.rand(3, 3), dtype=np.float64), + } + sdfg_args_res = copy.deepcopy(sdfg_args_ref) + + # Perform the reference execution + sdfg(**sdfg_args_ref) + + # Apply the transformation and execute the SDFG again. + count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) + sdfg(**sdfg_args_res) + + for name in sdfg_args_ref.keys(): + self.assertTrue( + np.allclose(sdfg_args_ref[name], sdfg_args_res[name]), + f"Miss match for {name}", + ) + self.assertFalse(count > 0) + if __name__ == '__main__': unittest.main() From 251833fcf119d77ed99d26695c2bcee97fee7fa4 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 15 Oct 2024 09:54:36 +0200 Subject: [PATCH 042/108] Adapt scalar fission and symbol ssa --- dace/transformation/pass_pipeline.py | 7 +- .../passes/analysis/analysis.py | 277 ++++++++++-------- dace/transformation/passes/scalar_fission.py | 4 +- dace/transformation/passes/symbol_ssa.py | 18 +- tests/passes/scalar_fission_test.py | 29 +- ...calar_write_shadow_scopes_analysis_test.py | 77 +++-- 6 files changed, 247 insertions(+), 165 deletions(-) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 2067168ca4..3689b3f4ac 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -271,8 +271,7 @@ class ControlFlowRegionPass(Pass): CATEGORY: str = 'Helper' - def apply_pass(self, sdfg: SDFG, - pipeline_results: Dict[str, Any]) -> Optional[Dict[ControlFlowRegion, Optional[Any]]]: + def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[int, Optional[Any]]]: """ Applies the pass to control flow regions of the given SDFG by calling ``apply`` on each region. @@ -280,14 +279,14 @@ def apply_pass(self, sdfg: SDFG, :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass results as ``{Pass subclass name: returned object from pass}``. If not run in a pipeline, an empty dictionary is expected. - :return: A dictionary of ``{region: return value}`` for visited regions with a non-None return value, or None + :return: A dictionary of ``{cfg_id: return value}`` for visited regions with a non-None return value, or None if nothing was returned. """ result = {} for region in sdfg.all_control_flow_regions(recursive=True): retval = self.apply(region, pipeline_results) if retval is not None: - result[region] = retval + result[region.cfg_id] = retval if not result: return None diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index d8c791f4e5..7c89dd9b79 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -1,19 +1,21 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion + +import blinker +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd -from dace.sdfg.analysis import cfg +from dace.sdfg.analysis import cfg as cfg_analysis from typing import Dict, Set, Tuple, Any, Optional, Union import networkx as nx from networkx.algorithms import shortest_paths as nxsp WriteScopeDict = Dict[str, Dict[Optional[Tuple[SDFGState, nd.AccessNode]], - Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]] -SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]] + Set[Union[Tuple[SDFGState, nd.AccessNode], Tuple[ControlFlowBlock, InterstateEdge]]]]] +SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], ControlFlowBlock]]]] @properties.make_properties @@ -101,7 +103,7 @@ def _region_closure(self, region: ControlFlowRegion, def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: """ :return: For each control flow region, a dictionary mapping each control flow block to its other reachable - control flow blocks in the same region. + control flow blocks. """ single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( lambda: defaultdict(set) @@ -183,7 +185,7 @@ def reachable_nodes(G): @properties.make_properties @transformation.experimental_cfg_block_compatible -class SymbolAccessSets(ppl.Pass): +class SymbolAccessSets(ppl.ControlFlowRegionPass): """ Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). """ @@ -197,33 +199,25 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass(self, top_sdfg: SDFG, - _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: - """ - :return: A dictionary mapping each state and interstate edge to a tuple of its (read, written) symbols. - """ - top_result: Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} - for sdfg in top_sdfg.all_sdfgs_recursive(): - for cfg in sdfg.all_control_flow_regions(): - adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for block in cfg.nodes(): - if isinstance(block, SDFGState): - # No symbols may be written to inside states. - result[block] = (block.free_symbols, set()) - for oedge in cfg.out_edges(block): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[cfg.cfg_id] = result - return top_result + def apply(self, region: ControlFlowRegion, _) -> Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], + Tuple[Set[str], Set[str]]]: + adesc = set(region.sdfg.arrays.keys()) + result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + for block in region.nodes(): + # No symbols may be written to inside blocks. + result[block] = (block.free_symbols, set()) + for oedge in region.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + return result @properties.make_properties @transformation.experimental_cfg_block_compatible class AccessSets(ppl.Pass): """ - Evaluates memory access sets (which arrays/data descriptors are read/written in each state). + Evaluates memory access sets (which arrays/data descriptors are read/written in each control flow block). """ CATEGORY: str = 'Analysis' @@ -232,25 +226,33 @@ def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply + # If access nodes were modified, reapply return modified & ppl.Modifies.AccessNodes - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.states(): + result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + for block in sdfg.all_control_flow_blocks(): readset, writeset = set(), set() - for anode in state.data_nodes(): - if state.in_degree(anode) > 0: - writeset.add(anode.data) - if state.out_degree(anode) > 0: - readset.add(anode.data) - - result[state] = (readset, writeset) + if isinstance(block, SDFGState): + for anode in block.data_nodes(): + if block.in_degree(anode) > 0: + writeset.add(anode.data) + if block.out_degree(anode) > 0: + readset.add(anode.data) + elif isinstance(block, AbstractControlFlowRegion): + for state in block.all_states(): + for anode in state.data_nodes(): + if state.in_degree(anode) > 0: + writeset.add(anode.data) + if state.out_degree(anode) > 0: + readset.add(anode.data) + + result[block] = (readset, writeset) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() @@ -341,10 +343,10 @@ def apply_pass(self, top_sdfg: SDFG, @properties.make_properties -@transformation.single_level_sdfg_only -class SymbolWriteScopes(ppl.Pass): +@transformation.experimental_cfg_block_compatible +class SymbolWriteScopes(ppl.ControlFlowRegionPass): """ - For each symbol, create a dictionary mapping each writing interstate edge to that symbol to the set of interstate + For each symbol, create a dictionary mapping each interstate edge writing to that symbol to the set of interstate edges and states reading that symbol that are dominated by that write. """ @@ -354,17 +356,16 @@ def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.Symbols | ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Nodes + return modified & ppl.Modifies.Symbols | ppl.Modifies.CFG | ppl.Modifies.Edges | ppl.Modifies.Nodes def depends_on(self): - return {SymbolAccessSets, StateReachability} + return {SymbolAccessSets, ControlFlowBlockReachability} - def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], - state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: - last_state: SDFGState = read if isinstance(read, SDFGState) else read.src + def _find_dominating_write(self, sym: str, read: Union[ControlFlowBlock, Edge[InterstateEdge]], + block_idom: Dict[ControlFlowBlock, ControlFlowBlock]) -> Optional[Edge[InterstateEdge]]: + last_block: ControlFlowBlock = read if isinstance(read, ControlFlowBlock) else read.src - in_edges = last_state.parent.in_edges(last_state) + in_edges = last_block.parent_graph.in_edges(last_block) deg = len(in_edges) if deg == 0: return None @@ -372,9 +373,9 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat return in_edges[0] write_isedge = None - n_state = state_idom[last_state] if state_idom[last_state] != last_state else None - while n_state is not None and write_isedge is None: - oedges = n_state.parent.out_edges(n_state) + n_block = block_idom[last_block] if block_idom[last_block] != last_block else None + while n_block is not None and write_isedge is None: + oedges = n_block.parent_graph.out_edges(n_block) odeg = len(oedges) if odeg == 1: if any([sym == k for k in oedges[0].data.assignments.keys()]): @@ -382,71 +383,68 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat else: dom_edge = None for cand in oedges: - if nxsp.has_path(n_state.parent.nx, cand.dst, last_state): + if nxsp.has_path(n_block.parent_graph.nx, cand.dst, last_block): if dom_edge is not None: dom_edge = None break elif any([sym == k for k in cand.data.assignments.keys()]): dom_edge = cand write_isedge = dom_edge - n_state = state_idom[n_state] if state_idom[n_state] != n_state else None + n_block = block_idom[n_block] if block_idom[n_block] != n_block else None return write_isedge - def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, SymbolScopeDict]: - top_result: Dict[int, SymbolScopeDict] = dict() + def apply(self, region, pipeline_results) -> SymbolScopeDict: + result: SymbolScopeDict = defaultdict(lambda: defaultdict(lambda: set())) - for sdfg in sdfg.all_sdfgs_recursive(): - result: SymbolScopeDict = defaultdict(lambda: defaultdict(lambda: set())) + idom = nx.immediate_dominators(region.nx, region.start_block) + all_doms = cfg_analysis.all_dominators(region, idom) - idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) - all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], - Tuple[Set[str], - Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.cfg_id] - state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id] + b_reach: Dict[ControlFlowBlock, + Set[ControlFlowBlock]] = pipeline_results[ControlFlowBlockReachability.__name__][region.cfg_id] + symbol_access_sets: Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], + Tuple[Set[str], Set[str]]] = pipeline_results[SymbolAccessSets.__name__][region.cfg_id] - for read_loc, (reads, _) in symbol_access_sets.items(): - for sym in reads: - dominating_write = self._find_dominating_write(sym, read_loc, idom) - result[sym][dominating_write].add(read_loc if isinstance(read_loc, SDFGState) else read_loc) + for read_loc, (reads, _) in symbol_access_sets.items(): + for sym in reads: + dominating_write = self._find_dominating_write(sym, read_loc, idom) + result[sym][dominating_write].add(read_loc if isinstance(read_loc, ControlFlowBlock) else read_loc.dst) - # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, - # then merge A and its scope into B's scope. - to_remove = set() - for sym in result.keys(): - for write, accesses in result[sym].items(): - if write is None: - continue - dominators = all_doms[write.dst] - reach = state_reach[write.dst] - for dom in dominators: - iedges = dom.parent.in_edges(dom) - if len(iedges) == 1 and iedges[0] in result[sym]: - other_accesses = result[sym][iedges[0]] - coarsen = False - for a_state_or_edge in other_accesses: - if isinstance(a_state_or_edge, SDFGState): - if a_state_or_edge in reach: - coarsen = True - break - else: - if a_state_or_edge.src in reach: - coarsen = True - break - if coarsen: - other_accesses.update(accesses) - other_accesses.add(write) - to_remove.add((sym, write)) - result[sym][write] = set() - for sym, write in to_remove: - del result[sym][write] - - top_result[sdfg.cfg_id] = result - return top_result + # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, then merge + # A and its scope into B's scope. + to_remove = set() + for sym in result.keys(): + for write, accesses in result[sym].items(): + if write is None: + continue + dominators = all_doms[write.dst] + reach = b_reach[write.dst] + for dom in dominators: + iedges = dom.parent_graph.in_edges(dom) + if len(iedges) == 1 and iedges[0] in result[sym]: + other_accesses = result[sym][iedges[0]] + coarsen = False + for a_state_or_edge in other_accesses: + if isinstance(a_state_or_edge, SDFGState): + if a_state_or_edge in reach: + coarsen = True + break + else: + if a_state_or_edge.src in reach: + coarsen = True + break + if coarsen: + other_accesses.update(accesses) + other_accesses.add(write) + to_remove.add((sym, write)) + result[sym][write] = set() + for sym, write in to_remove: + del result[sym][write] + + return result @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class ScalarWriteShadowScopes(ppl.Pass): """ For each scalar or array of size 1, create a dictionary mapping writes to that data container to the set of reads @@ -467,13 +465,14 @@ def depends_on(self): def _find_dominating_write(self, desc: str, - state: SDFGState, + block: ControlFlowBlock, read: Union[nd.AccessNode, InterstateEdge], access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], - state_idom: Dict[SDFGState, SDFGState], - access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], + idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]], + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]], no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): + state: SDFGState = block # If the read is also a write, it shadows itself. iedges = state.in_edges(read) if len(iedges) > 0 and any(not e.data.is_empty() for e in iedges) and not no_self_shadowing: @@ -489,24 +488,31 @@ def _find_dominating_write(self, closest_candidate = cand if closest_candidate is not None: return (state, closest_candidate) - elif isinstance(read, InterstateEdge): + elif isinstance(read, InterstateEdge) and isinstance(block, SDFGState): # Attempt to find a shadowing write in the current state. # TODO: Can this be done more efficiently? closest_candidate = None - write_nodes = access_nodes[desc][state][1] + write_nodes = access_nodes[desc][block][1] for cand in write_nodes: - if closest_candidate is None or nxsp.has_path(state._nx, closest_candidate, cand): + if closest_candidate is None or nxsp.has_path(block._nx, closest_candidate, cand): closest_candidate = cand if closest_candidate is not None: - return (state, closest_candidate) + return (block, closest_candidate) - # Find the dominating write state if the current state is not the dominating write state. + # Find the dominating write state if the current block is not the dominating write state. write_state = None - nstate = state_idom[state] if state_idom[state] != state else None - while nstate is not None and write_state is None: - if desc in access_sets[nstate][1]: - write_state = nstate - nstate = state_idom[nstate] if state_idom[nstate] != nstate else None + pivot_block = block + region = block.parent_graph + while region is not None and write_state is None: + nblock = idom_dict[region][pivot_block] if idom_dict[region][pivot_block] != block else None + while nblock is not None and write_state is None: + if isinstance(nblock, SDFGState) and desc in access_sets[nblock][1]: + write_state = nblock + nblock = idom_dict[region][nblock] if idom_dict[region][nblock] != nblock else None + # No dominating write found in the current control flow graph, check one further up. + if write_state is None: + pivot_block = region + region = region.parent_graph # Find a dominating write in the write state, i.e., the 'last' write to the data container. if write_state is not None: @@ -532,12 +538,28 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i for sdfg in top_sdfg.all_sdfgs_recursive(): result: WriteScopeDict = defaultdict(lambda: defaultdict(lambda: set())) - idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) - all_doms = cfg.all_dominators(sdfg, idom) - access_sets: Dict[SDFGState, Tuple[Set[str], - Set[str]]] = pipeline_results[AccessSets.__name__][sdfg.cfg_id] + idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} + all_doms_transitive: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(lambda: set()) + for cfg in sdfg.all_control_flow_regions(): + if isinstance(cfg, ConditionalBlock): + idom_dict[cfg] = {b: b for _, b in cfg.branches} + all_doms = {b: set([b]) for _, b in cfg.branches} + else: + idom_dict[cfg] = nx.immediate_dominators(cfg.nx, cfg.start_block) + all_doms = cfg_analysis.all_dominators(cfg, idom_dict[cfg]) + + # Since all_control_flow_regions goes top-down in the graph hierarchy, we can build a transitive + # closure of all dominators her. + for k in all_doms.keys(): + all_doms_transitive[k].update(all_doms[k]) + all_doms_transitive[k].add(cfg) + all_doms_transitive[k].update(all_doms_transitive[cfg]) + + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], + Set[str]]] = pipeline_results[AccessSets.__name__][sdfg.cfg_id] access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[ FindAccessNodes.__name__][sdfg.cfg_id] + state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id] anames = sdfg.arrays.keys() @@ -545,18 +567,19 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i desc_states_with_nodes = set(access_nodes[desc].keys()) for state in desc_states_with_nodes: for read_node in access_nodes[desc][state][0]: - write = self._find_dominating_write(desc, state, read_node, access_nodes, idom, access_sets) + write = self._find_dominating_write(desc, state, read_node, access_nodes, idom_dict, + access_sets) result[desc][write].add((state, read_node)) # Ensure accesses to interstate edges are also considered. - for state, accesses in access_sets.items(): + for block, accesses in access_sets.items(): if desc in accesses[0]: - out_edges = sdfg.out_edges(state) + out_edges = block.parent_graph.out_edges(block) for oedge in out_edges: syms = oedge.data.free_symbols & anames if desc in syms: - write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom, + write = self._find_dominating_write(desc, block, oedge.data, access_nodes, idom_dict, access_sets) - result[desc][write].add((state, oedge.data)) + result[desc][write].add((block, oedge.data)) # Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not # dominating any reads and are thus not part of the results yet. for state in desc_states_with_nodes: @@ -566,7 +589,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i state, write_node, access_nodes, - idom, + idom_dict, access_sets, no_self_shadowing=True) result[desc][write].add((state, write_node)) @@ -578,7 +601,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i if write is None: continue write_state, write_node = write - dominators = all_doms[write_state] + dominators = all_doms_transitive[write_state] reach = state_reach[write_state] for other_write, other_accesses in result[desc].items(): if other_write is not None and other_write[1] is write_node and other_write[0] is write_state: diff --git a/dace/transformation/passes/scalar_fission.py b/dace/transformation/passes/scalar_fission.py index f691a861d7..0b234f2961 100644 --- a/dace/transformation/passes/scalar_fission.py +++ b/dace/transformation/passes/scalar_fission.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from typing import Any, Dict, Optional, Set @@ -8,7 +8,7 @@ from dace.transformation.passes import analysis as ap -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class ScalarFission(ppl.Pass): """ Fission transient scalars or arrays of size 1 that are dominated by a write into separate data containers. diff --git a/dace/transformation/passes/symbol_ssa.py b/dace/transformation/passes/symbol_ssa.py index fa59f88df7..29dec5b861 100644 --- a/dace/transformation/passes/symbol_ssa.py +++ b/dace/transformation/passes/symbol_ssa.py @@ -1,14 +1,15 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from typing import Any, Dict, Optional, Set -from dace import SDFG, SDFGState +from dace import SDFG +from dace.sdfg.state import ControlFlowBlock from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap -@transformation.single_level_sdfg_only -class StrictSymbolSSA(ppl.Pass): +@transformation.experimental_cfg_block_compatible +class StrictSymbolSSA(ppl.ControlFlowRegionPass): """ Perform an SSA transformation on all symbols in the SDFG in a strict manner, i.e., without introducing phi nodes. """ @@ -24,19 +25,20 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {ap.SymbolWriteScopes} - def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Set[str]]]: + def apply(self, region, pipeline_results) -> Optional[Dict[str, Set[str]]]: """ Rename symbols in a restricted SSA manner. - :param sdfg: The SDFG to modify. + :param region: The control flow region to modify. :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass results as ``{Pass subclass name: returned object from pass}``. If not run in a pipeline, an empty dictionary is expected. :return: A dictionary mapping the original name to a set of all new names created for each symbol. """ results: Dict[str, Set[str]] = defaultdict(lambda: set()) + sdfg = region if isinstance(region, SDFG) else region.sdfg - symbol_scope_dict: ap.SymbolScopeDict = pipeline_results[ap.SymbolWriteScopes.__name__][sdfg.cfg_id] + symbol_scope_dict: ap.SymbolScopeDict = pipeline_results[ap.SymbolWriteScopes.__name__][region.cfg_id] for name, scope_dict in symbol_scope_dict.items(): # If there is only one scope, don't do anything. @@ -58,7 +60,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D # Replace all dominated reads. for read in shadowed_reads: - if isinstance(read, SDFGState): + if isinstance(read, ControlFlowBlock): read.replace(name, newname) else: if read not in scope_dict: diff --git a/tests/passes/scalar_fission_test.py b/tests/passes/scalar_fission_test.py index adf66f5b1d..a1f8d1d20e 100644 --- a/tests/passes/scalar_fission_test.py +++ b/tests/passes/scalar_fission_test.py @@ -6,9 +6,12 @@ import dace from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.scalar_fission import ScalarFission +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches -def test_scalar_fission(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_fission(with_raising = False): """ Test the scalar fission pass. This heavily relies on the scalar write shadow scopes pass, which is tested separately. @@ -95,6 +98,9 @@ def test_scalar_fission(): sdfg.add_edge(loop_2_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarFission()]) pipeline.apply_pass(sdfg, {}) @@ -107,7 +113,8 @@ def test_scalar_fission(): assert all([n.data == list(tmp1_edge.assignments.values())[0] for n in [tmp1_write, loop1_read_tmp]]) assert all([n.data == list(tmp2_edge.assignments.values())[0] for n in [tmp2_write, loop2_read_tmp]]) -def test_branch_subscopes_nofission(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_branch_subscopes_nofission(with_raising = False): sdfg = dace.SDFG('branch_subscope_fission') sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [2], dace.int32) @@ -185,11 +192,15 @@ def test_branch_subscopes_nofission(): right_after.add_edge(a10, None, t6, 'b', dace.Memlet('B[0]')) right_after.add_edge(t6, 'c', a11, None, dace.Memlet('C[0]')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + Pipeline([ScalarFission()]).apply_pass(sdfg, {}) assert set(sdfg.arrays.keys()) == {'A', 'B', 'C'} -def test_branch_subscopes_fission(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_branch_subscopes_fission(with_raising = False): sdfg = dace.SDFG('branch_subscope_fission') sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [2], dace.int32) @@ -277,11 +288,17 @@ def test_branch_subscopes_fission(): merge_1.add_edge(a13, None, t8, 'b', dace.Memlet('B[0]')) merge_1.add_edge(t8, 'c', a14, None, dace.Memlet('C[0]')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + Pipeline([ScalarFission()]).apply_pass(sdfg, {}) assert set(sdfg.arrays.keys()) == {'A', 'B', 'C', 'B_0', 'B_1'} if __name__ == '__main__': - test_scalar_fission() - test_branch_subscopes_nofission() - test_branch_subscopes_fission() + test_scalar_fission(False) + test_branch_subscopes_nofission(False) + test_branch_subscopes_fission(False) + test_scalar_fission(True) + test_branch_subscopes_nofission(True) + test_branch_subscopes_fission(True) diff --git a/tests/passes/scalar_write_shadow_scopes_analysis_test.py b/tests/passes/scalar_write_shadow_scopes_analysis_test.py index b833a12a94..f0648cc2ba 100644 --- a/tests/passes/scalar_write_shadow_scopes_analysis_test.py +++ b/tests/passes/scalar_write_shadow_scopes_analysis_test.py @@ -1,14 +1,16 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar write shadowing analysis pass. """ import pytest - import dace from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis import ScalarWriteShadowScopes +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches -def test_scalar_write_shadow_split(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_split(with_raising = False): """ Test the scalar write shadow scopes pass with writes dominating reads across state. """ @@ -90,6 +92,9 @@ def test_scalar_write_shadow_split(): sdfg.add_edge(loop_2_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -106,7 +111,8 @@ def test_scalar_write_shadow_split(): } -def test_scalar_write_shadow_fused(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_fused(with_raising = False): """ Test the scalar write shadow scopes pass with writes dominating reads in the same state. """ @@ -176,6 +182,9 @@ def test_scalar_write_shadow_fused(): sdfg.add_edge(loop_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -186,7 +195,8 @@ def test_scalar_write_shadow_fused(): assert results[0]['B'][None] == {(loop_1, b1_read), (loop_2, b2_read), (loop_1, b1_write), (loop_2, b2_write)} -def test_scalar_write_shadow_interstate_self(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_interstate_self(with_raising = False): """ Tests the scalar write shadow pass with interstate edge reads being shadowed by the state they're originating from. """ @@ -270,6 +280,9 @@ def test_scalar_write_shadow_interstate_self(): sdfg.add_edge(loop_2_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -286,7 +299,8 @@ def test_scalar_write_shadow_interstate_self(): } -def test_scalar_write_shadow_interstate_pred(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_interstate_pred(with_raising = False): """ Tests the scalar write shadow pass with interstate edge reads being shadowed by a predecessor state. """ @@ -374,6 +388,9 @@ def test_scalar_write_shadow_interstate_pred(): sdfg.add_edge(loop_2_3, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -390,7 +407,8 @@ def test_scalar_write_shadow_interstate_pred(): } -def test_loop_fake_shadow(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_loop_fake_shadow(with_raising = False): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -432,13 +450,17 @@ def test_loop_fake_shadow(): sdfg.add_edge(loop2, guard, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard, end, dace.InterstateEdge(condition='i >= 10')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] assert res[0]['A'][(init, init_access)] == {(loop, loop_access), (loop2, loop2_access), (end, end_access)} -def test_loop_fake_complex_shadow(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_loop_fake_complex_shadow(with_raising = False): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -472,13 +494,17 @@ def test_loop_fake_complex_shadow(): sdfg.add_edge(loop2, guard, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard, end, dace.InterstateEdge(condition='i >= 10')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] assert res[0]['A'][(init, init_access)] == {(loop, loop_access), (loop2, loop2_access)} -def test_loop_real_shadow(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_loop_real_shadow(with_raising = False): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -514,6 +540,9 @@ def test_loop_real_shadow(): sdfg.add_edge(loop2, guard, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard, end, dace.InterstateEdge(condition='i >= 10')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -521,7 +550,8 @@ def test_loop_real_shadow(): assert res[0]['A'][(loop2, loop2_access)] == {(loop2, loop2_access)} -def test_dominationless_write_branch(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_dominationless_write_branch(with_raising = False): sdfg = dace.SDFG('dominationless_write_branch') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -558,6 +588,9 @@ def test_dominationless_write_branch(): sdfg.add_edge(guard, merge, dace.InterstateEdge(condition='B[0] >= 10')) sdfg.add_edge(left, merge, dace.InterstateEdge()) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -566,11 +599,19 @@ def test_dominationless_write_branch(): if __name__ == '__main__': - test_scalar_write_shadow_split() - test_scalar_write_shadow_fused() - test_scalar_write_shadow_interstate_self() - test_scalar_write_shadow_interstate_pred() - test_loop_fake_shadow() - test_loop_fake_complex_shadow() - test_loop_real_shadow() - test_dominationless_write_branch() + test_scalar_write_shadow_split(False) + test_scalar_write_shadow_fused(False) + test_scalar_write_shadow_interstate_self(False) + test_scalar_write_shadow_interstate_pred(False) + test_loop_fake_shadow(False) + test_loop_fake_complex_shadow(False) + test_loop_real_shadow(False) + test_dominationless_write_branch(False) + test_scalar_write_shadow_split(True) + test_scalar_write_shadow_fused(True) + test_scalar_write_shadow_interstate_self(True) + test_scalar_write_shadow_interstate_pred(True) + test_loop_fake_shadow(True) + test_loop_fake_complex_shadow(True) + test_loop_real_shadow(True) + test_dominationless_write_branch(True) From 5b7bdad3029fc9658f720117f744def25c4cdb7b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 15 Oct 2024 10:25:26 +0200 Subject: [PATCH 043/108] Fixes --- dace/codegen/control_flow.py | 11 ++-- dace/transformation/dataflow/map_fission.py | 6 +- dace/transformation/helpers.py | 57 ++++++++++--------- .../passes/analysis/analysis.py | 1 - .../simplification/control_flow_raising.py | 3 + tests/sdfg/conditional_region_test.py | 1 + tests/sdfg/free_symbols_test.py | 4 +- 7 files changed, 45 insertions(+), 38 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index d7272a214f..6657d09808 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -1099,13 +1099,14 @@ def make_empty_block(): return visited - {stop} -def structured_control_flow_tree_with_regions(sdfg: SDFG, dispatch_state: Callable[[SDFGState], str]) -> ControlFlow: +def structured_control_flow_tree_with_regions(cfg: ControlFlowRegion, + dispatch_state: Callable[[SDFGState], str]) -> ControlFlow: """ - Returns a structured control-flow tree (i.e., with constructs such as branches and loops) from an SDFG based on the + Returns a structured control-flow tree (i.e., with constructs such as branches and loops) from a CFG based on the control flow regions it contains. - :param sdfg: The SDFG to iterate over. - :return: Control-flow block representing the entire SDFG. + :param cfg: The graph to iterate over. + :return: Control-flow block representing the entire graph. """ root_block = GeneralBlock(dispatch_state=dispatch_state, parent=None, @@ -1117,7 +1118,7 @@ def structured_control_flow_tree_with_regions(sdfg: SDFG, dispatch_state: Callab gotos_to_break=[], assignments_to_ignore=[], sequential=True) - _structured_control_flow_traversal_with_regions(sdfg, dispatch_state, root_block) + _structured_control_flow_traversal_with_regions(cfg, dispatch_state, root_block) _reset_block_parents(root_block) return root_block diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index 89e3d2d90f..f3a2be08b7 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -1,19 +1,19 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Map Fission transformation. """ from copy import deepcopy as dcpy from collections import defaultdict -from dace import registry, sdfg as sd, memlet as mm, subsets, data as dt +from dace import sdfg as sd, memlet as mm, subsets, data as dt from dace.codegen import control_flow as cf from dace.sdfg import nodes, graph as gr from dace.sdfg import utils as sdutil -from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.propagation import propagate_memlets_state, propagate_subset from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation, helpers from typing import List, Optional, Tuple +@transformation.single_level_sdfg_only class MapFission(transformation.SingleStateTransformation): """ Implements the MapFission transformation. Map fission refers to subsuming a map scope into its internal subgraph, diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index da45c0bf04..c69ba5a149 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -4,7 +4,7 @@ import itertools from networkx import MultiDiGraph -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.subsets import Range, Subset, union import dace.subsets as subsets from typing import Dict, List, Optional, Tuple, Set, Union @@ -244,33 +244,36 @@ def _copy_state(sdfg: SDFG, return state_copy -def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: +def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock, Set[ControlFlowBlock]]: """ - Partitions the SDFG to subgraphs that can be nested independently of each other. The method does not nest the - subgraphs but alters the SDFG; (1) interstate edges are split, (2) scope source/sink states that belong to multiple + Partitions a CFG to subgraphs that can be nested independently of each other. The method does not nest the + subgraphs but alters the graph; (1) interstate edges are split, (2) scope source/sink nodes that belong to multiple scopes are duplicated (see _copy_state). - :param sdfg: The SDFG to be partitioned. - :return: The found subgraphs in the form of a dictionary where the keys are the start state of the subgraphs and the - values are the sets of SDFGStates contained withing each subgraph. + :param cfg: The graph to be partitioned. + :return: The found subgraphs in the form of a dictionary where the keys are the start block of the subgraphs and the + values are the sets of ControlFlowBlocks contained withing each subgraph. """ - split_interstate_edges(sdfg) + split_interstate_edges(cfg) - # Create a unique sink state to avoid issues with finding control flow. - sink_states = sdfg.sink_nodes() - if len(sink_states) > 1: - new_sink = sdfg.add_state('common_sink') - for s in sink_states: - sdfg.add_edge(s, new_sink, InterstateEdge()) + # Create a unique sink block to avoid issues with finding control flow. + sink_nodes = cfg.sink_nodes() + if len(sink_nodes) > 1: + new_sink = cfg.add_state('common_sink') + for s in sink_nodes: + cfg.add_edge(s, new_sink, InterstateEdge()) - ipostdom = utils.postdominators(sdfg) - cft = cf.structured_control_flow_tree(sdfg, None) + ipostdom = utils.postdominators(cfg) + if cfg.root_sdfg.using_experimental_blocks: + cft = cf.structured_control_flow_tree_with_regions(cfg, None) + else: + cft = cf.structured_control_flow_tree(cfg, None) - # Iterate over the SDFG's control flow scopes and create for each an SDFG subraph. These subgraphs must be disjoint, - # so we duplicate SDFGStates that appear in more than one scopes (guards and exits of loops and conditionals). - components = {} - visited = {} # Dict[SDFGState, bool]: True if SDFGState in Scope (non-SingleState) + # Iterate over the graph's control flow scopes and create for each a subraph. These subgraphs must be disjoint, + # so we duplicate blocks that appear in more than one scopes (guards and exits of loops and conditionals). + components: Dict[ControlFlowBlock, Tuple[Set[ControlFlowBlock], ControlFlowBlock]] = {} + visited: Dict[ControlFlowBlock, bool] = {} # Block -> True if block in Scope (non-SingleState) for i, child in enumerate(cft.children): if isinstance(child, cf.BasicCFBlock): if child.state in visited: @@ -281,18 +284,18 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: guard = child.guard fexit = None condition = child.condition if isinstance(child, cf.ForScope) else child.test - for e in sdfg.out_edges(guard): + for e in cfg.out_edges(guard): if e.data.condition != condition: fexit = e.dst break if fexit is None: raise ValueError("Cannot find for-scope's exit states.") - states = set(utils.dfs_conditional(sdfg, [guard], lambda p, _: p is not fexit)) + states = set(utils.dfs_conditional(cfg, [guard], lambda p, _: p is not fexit)) if guard in visited: if visited[guard]: - guard_copy = _copy_state(sdfg, guard, False, states) + guard_copy = _copy_state(cfg, guard, False, states) guard.remove_nodes_from(guard.nodes()) states.remove(guard) states.add(guard_copy) @@ -303,7 +306,7 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: if not (i == len(cft.children) - 2 and isinstance(cft.children[i + 1], cf.BasicCFBlock) and cft.children[i + 1].state is fexit): - fexit_copy = _copy_state(sdfg, fexit, True, states) + fexit_copy = _copy_state(cfg, fexit, True, states) fexit.remove_nodes_from(fexit.nodes()) states.remove(fexit) states.add(fexit_copy) @@ -314,11 +317,11 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: guard = child.branch_block ifexit = ipostdom[guard] - states = set(utils.dfs_conditional(sdfg, [guard], lambda p, _: p is not ifexit)) + states = set(utils.dfs_conditional(cfg, [guard], lambda p, _: p is not ifexit)) if guard in visited: if visited[guard]: - guard_copy = _copy_state(sdfg, guard, False, states) + guard_copy = _copy_state(cfg, guard, False, states) guard.remove_nodes_from(guard.nodes()) states.remove(guard) states.add(guard_copy) @@ -329,7 +332,7 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: if not (i == len(cft.children) - 2 and isinstance(cft.children[i + 1], cf.BasicCFBlock) and cft.children[i + 1].state is ifexit): - ifexit_copy = _copy_state(sdfg, ifexit, True, states) + ifexit_copy = _copy_state(cfg, ifexit, True, states) ifexit.remove_nodes_from(ifexit.nodes()) states.remove(ifexit) states.add(ifexit_copy) diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 7c89dd9b79..914dc2b3e9 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -2,7 +2,6 @@ from collections import defaultdict -import blinker from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index abe305f12c..d6e1f4c460 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -31,6 +31,9 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) for region in cfgs: + if isinstance(region, ConditionalBlock): + continue + sinks = region.sink_nodes() dummy_exit = region.add_state('__DACE_DUMMY') for s in sinks: diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 0be40f43d3..38778cba2b 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -43,6 +43,7 @@ def test_serialization(): for j in range(10): cfg = ControlFlowRegion(f'cfg_{j}', sdfg) + cfg.add_state('noop') cond_region.add_branch(CodeBlock(f'i == {j}'), cfg) assert sdfg.is_valid() diff --git a/tests/sdfg/free_symbols_test.py b/tests/sdfg/free_symbols_test.py index 3d162203d1..51cd739000 100644 --- a/tests/sdfg/free_symbols_test.py +++ b/tests/sdfg/free_symbols_test.py @@ -55,7 +55,7 @@ def test_sdfg(): sdfg: dace.SDFG = fsymtest_multistate.to_sdfg() sdfg.simplify() # Test each state separately - for state in sdfg.nodes(): + for state in sdfg.states(): assert (state.free_symbols == {'k', 'N', 'M', 'L'} or state.free_symbols == set()) # The SDFG itself should have another free symbol assert sdfg.free_symbols == {'K', 'M', 'N', 'L'} @@ -67,7 +67,7 @@ def test_constants(): sdfg.add_constant('K', 5) sdfg.add_constant('L', 20) - for state in sdfg.nodes(): + for state in sdfg.states(): assert (state.free_symbols == {'k', 'N', 'M'} or state.free_symbols == set()) assert sdfg.free_symbols == {'M', 'N'} From 99225a5f0c7b41056e5297dfa7e82f71b5ba66a4 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 15 Oct 2024 13:14:03 +0200 Subject: [PATCH 044/108] Adapt reference reduction pass --- dace/transformation/helpers.py | 99 ++++++++++++++++++- .../passes/analysis/analysis.py | 39 ++++++++ .../passes/reference_reduction.py | 15 ++- 3 files changed, 142 insertions(+), 11 deletions(-) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index c69ba5a149..19bea0f0e8 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -7,12 +7,12 @@ from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.subsets import Range, Subset, union import dace.subsets as subsets -from typing import Dict, List, Optional, Tuple, Set, Union +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set, Union from dace import data, dtypes, symbolic from dace.codegen import control_flow as cf from dace.sdfg import nodes, utils -from dace.sdfg.graph import SubgraphView, MultiConnectorEdge +from dace.sdfg.graph import Edge, SubgraphView, MultiConnectorEdge from dace.sdfg.scope import ScopeSubgraphView, ScopeTree from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.sdfg import graph @@ -1555,3 +1555,98 @@ def make_map_internal_write_external(sdfg: SDFG, state: SDFGState, map_exit: nod memlet=Memlet(data=sink.data, subset=copy.deepcopy(subset), other_subset=copy.deepcopy(subset))) + + +def all_isedges_between(src: ControlFlowBlock, dst: ControlFlowBlock) -> Iterable[Edge[InterstateEdge]]: + """ + Helper function that generates an iterable of all edges potentially encountered between two control flow blocks. + """ + if src.sdfg is not dst.sdfg: + raise RuntimeError('Blocks reside in different SDFGs') + + if src.parent_graph is dst.parent_graph: + # Simple case where both blocks reside in the same graph: + edges = set() + for p in src.parent_graph.all_simple_paths(src, dst, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion): + edges.update(e.dst.all_interstate_edges()) + return edges + else: + # In the case where the two blocks are not in the same graph, we follow this procedure: + # 1. Collect the list of control flow regions on the direct path between the source and destination: + # a) Determine the 'lowest common parent' region + # b) Determine the list of parents of the source before the common parent is reached + # c) Determine the list of parents of the destination before the common parent is reached. + # 2. In each of the parents of the source, add all edges from the source or the next parent until the + # end(s) of each region to the result + # 3. In each of the destination's parents, add all edges from the start block on until the destination + # or next parent to the result. + # 4. In the lowest common parent region, find all edge paths between the next parent regions for both + # the source and destination. + # Note that for each edge, if the destination is a control flow region, any edges inside of it may also + # be on the path and consequently also need to be added. + edges = set() + + # Step 1.a): Find the lowest common parent region. + common_regions = set() + pivot_graph = src.parent_graph + all_parent_regions_src = [pivot_graph] + while not isinstance(pivot_graph, SDFG): + pivot_graph = pivot_graph.parent_graph + all_parent_regions_src.append(pivot_graph) + pivot_graph = dst.parent_graph + all_parent_regions_dst = [pivot_graph] + while not isinstance(pivot_graph, SDFG): + pivot_graph = pivot_graph.parent_graph + all_parent_regions_dst.append(pivot_graph) + if pivot_graph in all_parent_regions_src: + common_regions.add(pivot_graph) + + # Step 1.b) and 1.c): Determine the list of parents involved in the path for the source and destination. + involved_src: List[ControlFlowRegion] = [] + involved_dst: List[ControlFlowRegion] = [] + common_parent: ControlFlowRegion = None + for r in all_parent_regions_src: + if r not in common_regions: + involved_src.append(r) + else: + common_parent = r + break + for r in all_parent_regions_dst: + if r not in common_regions: + involved_dst.append(r) + else: + if r is not common_parent: + raise RuntimeError('No common parent found') + break + + # Step 2 + src_pivot = src + for r in involved_src: + for sink in r.sink_nodes(): + for p in r.all_simple_paths(src_pivot, sink, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion): + edges.update(e.dst.all_interstate_edges()) + src_pivot = r + # Step 3 + dst_pivot = dst + for r in involved_dst: + for p in r.all_simple_paths(r.start_block, dst_pivot, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion): + edges.update(e.dst.all_interstate_edges()) + dst_pivot = r + + # Step 4 + for p in common_parent.all_simple_paths(src_pivot, dst_pivot, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion) and not e.dst is dst_pivot: + edges.update(e.dst.all_interstate_edges()) + + return edges diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 914dc2b3e9..fae2ae2169 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -17,6 +17,45 @@ SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], ControlFlowBlock]]]] +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class InterstateEdgeReachability(ppl.Pass): + """ + Evaluates which interstate edges can be executed after each control flow block. + """ + + CATEGORY: str = 'Analysis' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + # If anything was modified, reapply + return modified & ppl.Modifies.CFG + + def depends_on(self): + return {ControlFlowBlockReachability} + + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: + """ + :return: A dictionary mapping each state to its other reachable states. + """ + # Ensure control flow block reachability is run if not run within a pipeline. + if pipeline_res is None or not ControlFlowBlockReachability.__name__ in pipeline_res: + cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) + else: + cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] + reachable: Dict[int, Dict[ControlFlowBlock, Set[Edge[InterstateEdge]]]] = {} + for sdfg in top_sdfg.all_sdfgs_recursive(): + result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) + for state in sdfg.states(): + for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: + if isinstance(reached, SDFGState): + result[state].add(reached) + reachable[sdfg.cfg_id] = result + return reachable + + @properties.make_properties @transformation.experimental_cfg_block_compatible class StateReachability(ppl.Pass): diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index dc5ae1eb7d..a04cd89e77 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -5,13 +5,13 @@ from dace import SDFG, SDFGState, data, properties, Memlet from dace.sdfg import nodes -from dace.sdfg.analysis import cfg from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.helpers import all_isedges_between from dace.transformation.passes import analysis as ap @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class ReferenceToView(ppl.Pass): """ Replaces Reference data descriptors that are only set to one source with views. @@ -135,13 +135,10 @@ def find_candidates( # Filter self and unreachable states if other_state is state or other_state not in reachable_states[state]: continue - for path in sdfg.all_simple_paths(state, other_state, as_edges=True): - for e in path: - # The symbol was modified/reassigned in one of the paths, skip - if fsyms & e.data.assignments.keys(): - result.remove(cand) - break - if cand not in result: + for e in all_isedges_between(state, other_state): + # The symbol was modified/reassigned in one of the paths, skip + if fsyms & e.data.assignments.keys(): + result.remove(cand) break if cand not in result: break From d6c7c8b7405c926eff87b7e7a5068ac4aa73528e Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 16 Oct 2024 11:52:01 +0200 Subject: [PATCH 045/108] Adapt constant propagation --- dace/codegen/targets/framecode.py | 6 - dace/frontend/python/astutils.py | 35 +- dace/sdfg/analysis/cfg.py | 9 +- dace/sdfg/replace.py | 66 ++-- dace/sdfg/state.py | 16 +- .../transformation/interstate/block_fusion.py | 2 +- dace/transformation/pass_pipeline.py | 2 +- .../passes/analysis/analysis.py | 2 +- .../passes/analysis/loop_analysis.py | 25 +- .../passes/constant_propagation.py | 312 ++++++++++++++---- .../passes/dead_state_elimination.py | 10 + tests/passes/constant_propagation_test.py | 30 +- 12 files changed, 354 insertions(+), 161 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index d71ea40fee..62dc828590 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -919,12 +919,6 @@ def generate_code(self, global_symbols.update(symbols) if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: - init_assignment = cfr.init_statement.code[0] - update_assignment = cfr.update_statement.code[0] - if isinstance(init_assignment, astutils.ast.Assign): - init_assignment = init_assignment.value - if isinstance(update_assignment, astutils.ast.Assign): - update_assignment = update_assignment.value if not cfr.loop_variable in interstate_symbols: l_end = loop_analysis.get_loop_end(cfr) l_start = loop_analysis.get_init_assignment(cfr) diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 425e94cd9f..4e6aa68651 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -1,9 +1,8 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various AST parsing utilities for DaCe. """ import ast import astunparse import copy -from collections import OrderedDict from io import StringIO import inspect import numbers @@ -12,7 +11,7 @@ import sys from typing import Any, Dict, List, Optional, Set, Union -from dace import dtypes, symbolic +from dace import symbolic if sys.version_info >= (3, 8): @@ -587,6 +586,36 @@ def visit_keyword(self, node: ast.keyword): return self.generic_visit(node) +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = unparse(node.value) + return self.generic_visit(node) + + +class ASTReplaceAssignmentRHS(ast.NodeVisitor): + + repl_visitor: ASTFindReplace + + def __init__(self, repl: Dict[str, str]): + self.repl_visitor = ASTFindReplace(repl) + + def visit_Assign(self, node: ast.Assign) -> Any: + self.repl_visitor.visit(node.value) + return self.generic_visit(node) + + class RemoveSubscripts(ast.NodeTransformer): def __init__(self, keywords: Set[str]): self.keywords = keywords diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index c96ef5aff0..eb9aea0e2b 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -377,10 +377,11 @@ def blockorder_topological_sort(cfg: ControlFlowRegion, elif isinstance(block, ConditionalBlock): if not ignore_nonstate_blocks: yield block - for _, branch in block.branches: - if not ignore_nonstate_blocks: - yield branch - yield from blockorder_topological_sort(branch, recursive, ignore_nonstate_blocks) + if recursive: + for _, branch in block.branches: + if not ignore_nonstate_blocks: + yield branch + yield from blockorder_topological_sort(branch, recursive, ignore_nonstate_blocks) elif isinstance(block, SDFGState): yield block else: diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index e3bea0b807..9b6086098e 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -1,9 +1,9 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains functionality to perform find-and-replace of symbols in SDFGs. """ import re import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional import sympy as sp @@ -96,6 +96,38 @@ def replace(subgraph: 'StateSubgraphView', name: str, new_name: str): replace_dict(subgraph, {name: new_name}) +def replace_in_codeblock(codeblock: properties.CodeBlock, repl: Dict[str, str], node: Optional[Any] = None): + code = codeblock.code + if isinstance(code, str) and code: + lang = codeblock.language + if lang is dtypes.Language.CPP: # Replace in C++ code + prefix = '' + tokenized = tokenize_cpp.findall(code) + active_replacements = set() + for name, new_name in repl.items(): + if name not in tokenized: + continue + # Use local variables and shadowing to replace + replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' + prefix = replacement + prefix + active_replacements.add(name) + + if prefix: + codeblock.code = prefix + code + if node and isinstance(node, dace.nodes.Tasklet): + # Ignore replaced symbols since they no longer exist as reads + node.ignored_symbols = node.ignored_symbols.union(active_replacements) + + else: + warnings.warn('Replacement of %s with %s was not made ' + 'for string tasklet code of language %s' % (name, new_name, lang)) + + elif codeblock.code is not None: + afr = ASTFindReplace(repl) + for stmt in codeblock.code: + afr.visit(stmt) + + def replace_properties_dict(node: Any, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): @@ -127,35 +159,7 @@ def replace_properties_dict(node: Any, if hasattr(node, 'in_connectors'): reduced_repl -= set(node.in_connectors.keys()) | set(node.out_connectors.keys()) reduced_repl = {k: repl[k] for k in reduced_repl} - code = propval.code - if isinstance(code, str) and code: - lang = propval.language - if lang is dtypes.Language.CPP: # Replace in C++ code - prefix = '' - tokenized = tokenize_cpp.findall(code) - active_replacements = set() - for name, new_name in reduced_repl.items(): - if name not in tokenized: - continue - # Use local variables and shadowing to replace - replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' - prefix = replacement + prefix - active_replacements.add(name) - - if prefix: - propval.code = prefix + code - if isinstance(node, dace.nodes.Tasklet): - # Ignore replaced symbols since they no longer exist as reads - node.ignored_symbols = node.ignored_symbols.union(active_replacements) - - else: - warnings.warn('Replacement of %s with %s was not made ' - 'for string tasklet code of language %s' % (name, new_name, lang)) - - elif propval.code is not None: - afr = ASTFindReplace(reduced_repl) - for stmt in propval.code: - afr.visit(stmt) + replace_in_codeblock(propval, reduced_repl, node) elif (isinstance(propclass, properties.DictProperty) and pname == 'symbol_mapping'): # Symbol mappings for nested SDFGs for symname, sym_mapping in propval.items(): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c6bad25a26..a5e8efee5a 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -13,6 +13,7 @@ import dace from dace.frontend.python import astutils +from dace.sdfg.replace import replace_in_codeblock import dace.serialize from dace import data as dt from dace import dtypes @@ -2766,16 +2767,19 @@ def add_state_after(self, ################################################################### # Traversal methods - def all_control_flow_regions(self, recursive=False) -> Iterator['AbstractControlFlowRegion']: + def all_control_flow_regions(self, recursive=False, parent_first=True) -> Iterator['AbstractControlFlowRegion']: """ Iterate over this and all nested control flow regions. """ - yield self + if parent_first: + yield self for block in self.nodes(): if isinstance(block, SDFGState) and recursive: for node in block.nodes(): if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_control_flow_regions(recursive=recursive) + yield from node.sdfg.all_control_flow_regions(recursive=recursive, parent_first=parent_first) elif isinstance(block, AbstractControlFlowRegion): - yield from block.all_control_flow_regions(recursive=recursive) + yield from block.all_control_flow_regions(recursive=recursive, parent_first=parent_first) + if not parent_first: + yield self def all_sdfgs_recursive(self) -> Iterator['SDFG']: """ Iterate over this and all nested SDFGs. """ @@ -3313,8 +3317,10 @@ def replace_dict(self, from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) - for _, region in self._branches: + for cond, region in self._branches: region.replace_dict(repl, symrepl, replace_in_graph) + if cond is not None: + replace_in_codeblock(cond, repl) def to_json(self, parent=None): json = super().to_json(parent) diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py index a890551ba1..6abd65fc87 100644 --- a/dace/transformation/interstate/block_fusion.py +++ b/dace/transformation/interstate/block_fusion.py @@ -81,7 +81,7 @@ def apply(self, graph: ControlFlowRegion, sdfg): connecting_edge = graph.edges_between(self.first_block, self.second_block)[0] assignments_to_absorb = connecting_edge.data.assignments graph.remove_edge(connecting_edge) - for ie in graph.out_edges(self.first_block): + for ie in graph.in_edges(self.first_block): if assignments_to_absorb: ie.data.assignments.update(assignments_to_absorb) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 3689b3f4ac..9651e7d208 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -283,7 +283,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D if nothing was returned. """ result = {} - for region in sdfg.all_control_flow_regions(recursive=True): + for region in sdfg.all_control_flow_regions(recursive=True, parent_first=False): retval = self.apply(region, pipeline_results) if retval is not None: result[region.cfg_id] = retval diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index fae2ae2169..447efea42c 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -225,7 +225,7 @@ def reachable_nodes(G): @transformation.experimental_cfg_block_compatible class SymbolAccessSets(ppl.ControlFlowRegionPass): """ - Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). + Evaluates symbol access sets (which symbols are read/written in each control flow block or interstate edge). """ CATEGORY: str = 'Analysis' diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 3d15f73c73..ec9d4d0c73 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -3,8 +3,7 @@ Various analyses concerning LopoRegions, and utility functions to get information about LoopRegions for other passes. """ -import ast -from typing import Any, Dict, Optional +from typing import Dict, Optional from dace.frontend.python import astutils import sympy @@ -13,24 +12,6 @@ from dace.sdfg.state import LoopRegion -class FindAssignment(ast.NodeVisitor): - - assignments: Dict[str, str] - multiple: bool - - def __init__(self): - self.assignments = {} - self.multiple = False - - def visit_Assign(self, node: ast.Assign) -> Any: - for tgt in node.targets: - if isinstance(tgt, ast.Name): - if tgt.id in self.assignments: - self.multiple = True - self.assignments[tgt.id] = astutils.unparse(node.value) - return self.generic_visit(node) - - def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: """ Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). @@ -68,7 +49,7 @@ def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] assignments: Dict[str, str] = {} for code in init_codes_list: - visitor = FindAssignment() + visitor = astutils.FindAssignment() visitor.visit(code) if visitor.multiple: return None @@ -94,7 +75,7 @@ def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] assignments: Dict[str, str] = {} for code in update_codes_list: - visitor = FindAssignment() + visitor = astutils.FindAssignment() visitor.visit(code) if visitor.multiple: return None diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index bfa0928415..f6311dea6f 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -1,11 +1,12 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from dataclasses import dataclass from dace.frontend.python import astutils -from dace.sdfg.analysis import cfg +from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.sdfg import InterstateEdge -from dace.sdfg import nodes, utils as sdutil +from dace.sdfg import nodes +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.transformation import pass_pipeline as ppl, transformation from dace.cli.progress import optional_progressbar from dace import data, SDFG, SDFGState, dtypes, symbolic, properties @@ -17,9 +18,13 @@ class _UnknownValue: pass +ConstsT = Dict[str, Any] +BlockConstsT = Dict[ControlFlowBlock, ConstsT] + + @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class ConstantPropagation(ppl.Pass): """ Propagates constants and symbols that were assigned to one value forward through the SDFG, reducing @@ -42,7 +47,7 @@ def should_apply(self, sdfg: SDFG) -> bool: """ Fast check (O(m)) whether the pass should early-exit without traversing the SDFG. """ - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): # If there are no assignments, there are no constants to propagate if len(edge.data.assignments) == 0: continue @@ -69,8 +74,28 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = if not initial_symbols and not self.should_apply(sdfg): result = {} else: - # Trace all constants and symbols through states - per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(sdfg, initial_symbols) + arrays: Set[str] = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) + + # Add nested data to arrays + def _add_nested_datanames(name: str, desc: data.Structure): + for k, v in desc.members.items(): + if isinstance(v, data.Structure): + _add_nested_datanames(f'{name}.{k}', v) + elif isinstance(v, data.ContainerArray): + # TODO: How are we handling this? + pass + arrays.add(f'{name}.{k}') + + for name, desc in sdfg.arrays.items(): + if isinstance(desc, data.Structure): + _add_nested_datanames(name, desc) + + # Trace all constants and symbols through blocks + in_consts: BlockConstsT = { sdfg: initial_symbols } + pre_consts: BlockConstsT = {} + post_consts: BlockConstsT = {} + out_consts: BlockConstsT = {} + self._collect_constants_for_region(sdfg, arrays, in_consts, pre_consts, post_consts, out_consts) # Keep track of replaced and ambiguous symbols symbols_replaced: Dict[str, Any] = {} @@ -78,13 +103,14 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = # Collect symbols from symbol-dependent data descriptors # If there can be multiple values over the SDFG, the symbols are not propagated - desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, per_state_constants) + desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_consts) # Replace constants per state - for state, mapping in optional_progressbar(per_state_constants.items(), - 'Propagating constants', - n=len(per_state_constants), + for block, mapping in optional_progressbar(in_consts.items(), 'Propagating constants', n=len(in_consts), progress=self.progress): + if block is sdfg: + continue + remaining_unknowns.update( {k for k, v in mapping.items() if v is _UnknownValue or k in multivalue_desc_symbols}) @@ -92,17 +118,36 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = k: v for k, v in mapping.items() if v is not _UnknownValue and k not in multivalue_desc_symbols } - if not mapping: - continue - - # Update replaced symbols for later replacements - symbols_replaced.update(mapping) - # Replace in state contents - state.replace_dict(mapping) - # Replace in outgoing edges as well - for e in sdfg.out_edges(state): - e.data.replace_dict(mapping, replace_keys=False) + if mapping: + # Update replaced symbols for later replacements + symbols_replaced.update(mapping) + + if isinstance(block, SDFGState): + # Replace in state contents + block.replace_dict(mapping) + elif isinstance(block, AbstractControlFlowRegion): + block.replace_dict(mapping, replace_in_graph=False, replace_keys=False) + + # Replace in outgoing edges as well + for e in block.parent_graph.out_edges(block): + e.data.replace_dict(mapping, replace_keys=False) + + if isinstance(block, LoopRegion): + if block in post_consts and post_consts[block] is not None: + if block.update_statement is not None and (block.inverted and block.update_before_condition or + not block.inverted): + # Replace the RHS of the update experssion + post_mapping = { + k: v + for k, v in post_consts[block].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols + } + update_stmt = block.update_statement + updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + for update in updates: + astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update) + block.update_statement.code = updates # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} @@ -114,7 +159,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = replace_keys=False) # Remove constant symbol assignments in interstate edges - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): intersection = result & edge.data.assignments.keys() for sym in intersection: del edge.data.assignments[sym] @@ -134,7 +179,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = sid = sdfg.cfg_id result = set((sid, sym) for sym in result) - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): nested_id = node.sdfg.cfg_id @@ -155,59 +200,139 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = def report(self, pass_retval: Set[str]) -> str: return f'Propagated {len(pass_retval)} constants.' - def collect_constants(self, - sdfg: SDFG, - initial_symbols: Optional[Dict[str, Any]] = None) -> Dict[SDFGState, Dict[str, Any]]: + def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arrays: Set[str], + in_const_dict: BlockConstsT, pre_const_dict: BlockConstsT, + post_const_dict: BlockConstsT, out_const_dict: BlockConstsT) -> None: """ - Finds all constants and constant-assigned symbols in the SDFG for each state. - - :param sdfg: The SDFG to traverse. - :param initial_symbols: If not None, sets values of initial symbols. - :return: A dictionary mapping an SDFG state to a mapping of constants and their corresponding values. + Collect the constants for and inside of a conditional region. + Recursively collects constants inside of nested regions. + + :param conditional: The conditional region to traverse. + :param arrays: A set of data descriptors in the SDFG. + :param in_const_dict: Dictionary mapping each control flow block to the set of constants observed right before + the block is executed. Populated by this function. + :param pre_const_dict: Dictionary mapping each control flow block to the set of constants observed before its + contents are executed. Populated by this function. + :param post_const_dict: Dictionary mapping each control flow block to the set of constants observed after its + contents are executed. Populated by this function. + :param out_const_dict: Dictionary mapping each control flow block to the set of constants observed right after + the block is executed. Populated by this function. + """ + in_consts = in_const_dict[conditional] + # First, collect all constants for each of the branches. + for _, branch in conditional.branches: + in_const_dict[branch] = in_consts + self._collect_constants_for_region(branch, arrays, in_const_dict, pre_const_dict, post_const_dict, + out_const_dict) + # Second, determine the 'post constants' (constants at the end of the conditional region) as an intersection + # between the output constants of each of the branches. + post_consts = {} + post_consts_intersection = None + has_else = False + for cond, branch in conditional.branches: + if post_consts_intersection is None: + post_consts_intersection = set(out_const_dict[branch].keys()) + else: + post_consts_intersection &= set(out_const_dict[branch].keys()) + if cond is None: + has_else = True + for _, branch in conditional.branches: + for k, v in out_const_dict[branch].items(): + if k in post_consts_intersection: + if k not in post_consts: + post_consts[k] = v + elif post_consts[k] != _UnknownValue and post_consts[k] != v: + post_consts[k] = _UnknownValue + else: + post_consts[k] = _UnknownValue + post_const_dict[conditional] = post_consts + + # Finally, determine the conditional region's output constants. + if has_else: + # If there is an else, at least one branch will certainly be taken, so the output constants are the region's + # post constants. + out_const_dict[conditional] = post_consts + else: + # No else branch is present, so it is possible that no branch is executed. In this case the out constants + # are the intersection between the in constants and the post constants. + out_consts = in_consts + for k, v in post_consts.items(): + if k not in out_consts: + out_consts[k] = _UnknownValue + elif out_consts[k] != _UnknownValue and out_consts[k] != v: + out_consts[k] = _UnknownValue + out_const_dict[conditional] = out_consts + + def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str], in_const_dict: BlockConstsT, + pre_const_dict: BlockConstsT, post_const_dict: BlockConstsT, + out_const_dict: BlockConstsT) -> None: """ - arrays: Set[str] = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) - result: Dict[SDFGState, Dict[str, Any]] = {} - - # Add nested data to arrays - def _add_nested_datanames(name: str, desc: data.Structure): - for k, v in desc.members.items(): - if isinstance(v, data.Structure): - _add_nested_datanames(f'{name}.{k}', v) - elif isinstance(v, data.ContainerArray): - # TODO: How are we handling this? - pass - arrays.add(f'{name}.{k}') - - for name, desc in sdfg.arrays.items(): - if isinstance(desc, data.Structure): - _add_nested_datanames(name, desc) + Finds all constants and constant-assigned symbols in the control flow graph for each block. + Recursively collects constants for nested control flow regions. + + :param cfg: The CFG to traverse. + :param arrays: A set of data descriptors in the SDFG. + :param in_const_dict: Dictionary mapping each control flow block to the set of constants observed right before + the block is executed. Populated by this function. + :param pre_const_dict: Dictionary mapping each control flow block to the set of constants observed before its + contents are executed. Populated by this function. + :param post_const_dict: Dictionary mapping each control flow block to the set of constants observed after its + contents are executed. Populated by this function. + :param out_const_dict: Dictionary mapping each control flow block to the set of constants observed right after + the block is executed. Populated by this function. + """ + # Given the 'in constants', i.e., the constants for before the current region is executed, compute the 'pre + # constants', i.e., the set of constants seen inside the region when executing. + if cfg in in_const_dict: + in_const = in_const_dict[cfg] + if isinstance(cfg, LoopRegion): + # In the case of a loop, the 'pre constants' are equivalent to the 'in constants', with the exception + # of values that may at any point be re-assigned inside the loop, since that assignment would carry over + # into the next iteration (including increments to the loop variable, if present). + assignments_within = set() + for e in cfg.all_interstate_edges(): + for k in e.data.assignments.keys(): + assignments_within.add(k) + if cfg.loop_variable is not None: + assignments_within.add(cfg.loop_variable) + pre_const = { k: (v if k not in assignments_within else _UnknownValue) for k, v in in_const.items() } + else: + # In any other case, the 'pre constants' are equivalent to the 'in constants'. + pre_const = {} + pre_const.update(in_const) + else: + # No 'in constants' for the current region - so initialize to nothing. + pre_const = {} + pre_const_dict[cfg] = pre_const + in_const = {} + pre_const_dict[cfg] = pre_const # Process: - # * Collect constants in topologically ordered states + # * Collect constants in topologically ordered blocks # * Propagate forward symbols forward and edge assignments # * If value is ambiguous (not the same), set value to UNKNOWN # * Repeat until no update is performed - start_state = sdfg.start_state - if initial_symbols: - result[start_state] = {} - result[start_state].update(initial_symbols) + start_block = cfg.start_block + if pre_const: + in_const_dict[start_block] = {} + in_const_dict[start_block].update(pre_const) redo = True while redo: redo = False - # Traverse SDFG topologically - for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants', - sdfg.number_of_nodes(), self.progress): - + # Traverse CFG topologically + for block in optional_progressbar(cfg_analysis.blockorder_topological_sort(cfg, recursive=False), + 'Collecting constants for ' + cfg.label, cfg.number_of_nodes(), + self.progress): # Get predecessors - in_edges = sdfg.in_edges(state) + in_edges = cfg.in_edges(block) assignments = {} for edge in in_edges: # If source was already visited, use its propagated constants constants: Dict[str, Any] = {} - if edge.src in result: - constants.update(result[edge.src]) + if edge.src in out_const_dict: + constants.update(out_const_dict[edge.src]) # Update constants with incoming edge self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) @@ -217,16 +342,15 @@ def _add_nested_datanames(name: str, desc: data.Structure): # If a symbol appearing in the replacing expression of a constant is modified, # the constant is not valid anymore if ((aname in assignments and aval != assignments[aname]) or - symbolic.free_symbols_and_functions(aval) & edge.data.assignments.keys()): + symbolic.free_symbols_and_functions(aval) & edge.data.assignments.keys()): assignments[aname] = _UnknownValue else: assignments[aname] = aval - for edge in sdfg.out_edges(state): + for edge in cfg.out_edges(block): for aname, aval in assignments.items(): - # If the specific replacement would result in the value - # being both used and reassigned on the same inter-state - # edge, remove it from consideration. + # If the specific replacement would result in the value being both used and reassigned on the + # same inter-state edge, remove it from consideration. replacements = symbolic.free_symbols_and_functions(aval) used_in_assignments = { k @@ -235,13 +359,59 @@ def _add_nested_datanames(name: str, desc: data.Structure): reassignments = replacements & edge.data.assignments.keys() if reassignments and (used_in_assignments - reassignments): assignments[aname] = _UnknownValue - - if state not in result: # Condition may evaluate to False when state is the start-state - result[state] = {} - redo |= self._propagate(result[state], assignments) - - return result - + if block not in in_const_dict: + in_const_dict[block] = {} + if assignments: + redo |= self._propagate(in_const_dict[block], assignments) + + if isinstance(block, SDFGState): + # Simple case, no change in constants through this block. + pre_const_dict[block] = in_const_dict[block] + post_const_dict[block] = in_const_dict[block] + out_const_dict[block] = in_const_dict[block] + elif isinstance(block, ControlFlowRegion): + self._collect_constants_for_region(block, arrays, in_const_dict, pre_const_dict, post_const_dict, + out_const_dict) + elif isinstance(block, ConditionalBlock): + self._collect_constants_for_conditional(block, arrays, in_const_dict, pre_const_dict, + post_const_dict, out_const_dict) + + # For all sink nodes, compute the overlapping set of constants between them, making sure all constants in the + # resulting intersection are actually constants (i.e., all blocks see the same constant value for them). This + # resulting overlap forms the 'post constants' of this CFG. + post_consts = {} + post_consts_intersection = None + sinks = cfg.sink_nodes() + for sink in sinks: + if post_consts_intersection is None: + post_consts_intersection = set(out_const_dict[sink].keys()) + else: + post_consts_intersection &= set(out_const_dict[sink].keys()) + for sink in sinks: + for k, v in out_const_dict[sink].items(): + if k in post_consts_intersection: + if k not in post_consts: + post_consts[k] = v + elif post_consts[k] != _UnknownValue and post_consts[k] != v: + post_consts[k] = _UnknownValue + else: + post_consts[k] = _UnknownValue + post_const_dict[cfg] = post_consts + + out_consts = {} + if isinstance(cfg, LoopRegion): + # For a loop we can not determine if it is being executed and how many times it would be executed. The 'out + # constants' are thus formed from the intersection of the loop's 'in constants' and 'post constants'. + out_consts.update(in_const) + for k, v in post_consts.items(): + if k not in out_consts: + out_consts[k] = _UnknownValue + elif out_consts[k] != _UnknownValue and out_consts[k] != v: + out_consts[k] = _UnknownValue + else: + out_consts.update(post_consts) + out_const_dict[cfg] = out_consts + def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[SDFGState, Dict[str, Any]]) -> Tuple[Set[str], Set[str]]: """ Finds constant symbols that data descriptors (e.g., arrays) depend on. diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 6a7e80fabf..cda193f43a 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -65,6 +65,16 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters for _, b in dead_branches: result.add(b) node.remove_branch(b) + # If only an 'else' is left over, inline it. + if len(node.branches) == 1 and node.branches[0][0] is None: + branch = node.branches[0][1] + node.parent_graph.add_node(branch) + for ie in cfg.in_edges(node): + cfg.add_edge(ie.src, branch, ie.data) + for oe in cfg.out_edges(node): + cfg.add_edge(branch, oe.dst, oe.data) + result.add(node) + cfg.remove_node(node) else: result.add(node) cfg.remove_node(block) diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index acb1033554..909e22a2b5 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -2,6 +2,7 @@ import pytest import dace +from dace.sdfg.state import LoopRegion from dace.transformation.passes.constant_propagation import ConstantPropagation, _UnknownValue from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion import numpy as np @@ -19,8 +20,6 @@ def program(A: dace.float64[20]): A[:] = cval + 4 sdfg = program.to_sdfg() - ScalarToSymbolPromotion().apply_pass(sdfg, {}) - ConstantPropagation().apply_pass(sdfg, {}) assert len(sdfg.symbols) == 0 for e in sdfg.edges(): @@ -41,8 +40,6 @@ def program(A: dace.int64[20]): A[l] = k sdfg = program.to_sdfg() - ScalarToSymbolPromotion().apply_pass(sdfg, {}) - ConstantPropagation().apply_pass(sdfg, {}) assert set(sdfg.symbols.keys()) == {'i'} @@ -66,10 +63,10 @@ def program(a: dace.float64[20]): a[0] = i # Use i - should be const sdfg = program.to_sdfg() - ScalarToSymbolPromotion().apply_pass(sdfg, {}) - ConstantPropagation().apply_pass(sdfg, {}) - assert set(sdfg.symbols.keys()) == {'i'} + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + assert node.loop_variable == 'i' # Test tasklets for node, _ in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.Tasklet): @@ -88,10 +85,10 @@ def program(a: dace.float64[20]): a[i] = i # Use i - not const sdfg = program.to_sdfg() - ScalarToSymbolPromotion().apply_pass(sdfg, {}) - ConstantPropagation().apply_pass(sdfg, {}) - assert set(sdfg.symbols.keys()) == {'i'} + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + assert node.loop_variable == 'i' # Test tasklets i_found = 0 @@ -115,10 +112,11 @@ def program(a: dace.float64[20, 20]): a[j, k] = 1 sdfg = program.to_sdfg() - ScalarToSymbolPromotion().apply_pass(sdfg, {}) - ConstantPropagation().apply_pass(sdfg, {}) - assert set(sdfg.symbols.keys()) == {'i', 'j'} + assert 'j' in sdfg.symbols + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + assert node.loop_variable == 'i' # Test memlet last_state = sdfg.sink_nodes()[0] @@ -143,8 +141,6 @@ def program(a: dace.float64[20, 20], scal: dace.int32): a[i, j] = 3 sdfg = program.to_sdfg() - ScalarToSymbolPromotion().apply_pass(sdfg, {}) - ConstantPropagation().apply_pass(sdfg, {}) assert len(sdfg.symbols.keys()) == 1 @@ -187,7 +183,9 @@ def test_complex_case(): sdfg.add_edge(usei, merge, dace.InterstateEdge(assignments={'j': 'j+1'})) sdfg.add_edge(merge, last, dace.InterstateEdge('j >= 2')) - propagated = ConstantPropagation().collect_constants(sdfg) #, reachability + propagated = {} + arrays = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) + ConstantPropagation()._collect_constants_for_region(sdfg, arrays, propagated, {}, {}, {}) assert len(propagated[init]) == 0 assert propagated[branch2]['i'] == '7' assert propagated[guard]['i'] is _UnknownValue From 607b09885553fca541bc560becd846dca63614ea Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 16 Oct 2024 13:30:01 +0200 Subject: [PATCH 046/108] Fix pytest arguments --- tests/passes/scalar_fission_test.py | 6 +++--- .../scalar_write_shadow_scopes_analysis_test.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/passes/scalar_fission_test.py b/tests/passes/scalar_fission_test.py index a1f8d1d20e..eeb959a926 100644 --- a/tests/passes/scalar_fission_test.py +++ b/tests/passes/scalar_fission_test.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize('with_raising', (False, True)) -def test_scalar_fission(with_raising = False): +def test_scalar_fission(with_raising): """ Test the scalar fission pass. This heavily relies on the scalar write shadow scopes pass, which is tested separately. @@ -114,7 +114,7 @@ def test_scalar_fission(with_raising = False): assert all([n.data == list(tmp2_edge.assignments.values())[0] for n in [tmp2_write, loop2_read_tmp]]) @pytest.mark.parametrize('with_raising', (False, True)) -def test_branch_subscopes_nofission(with_raising = False): +def test_branch_subscopes_nofission(with_raising): sdfg = dace.SDFG('branch_subscope_fission') sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [2], dace.int32) @@ -200,7 +200,7 @@ def test_branch_subscopes_nofission(with_raising = False): assert set(sdfg.arrays.keys()) == {'A', 'B', 'C'} @pytest.mark.parametrize('with_raising', (False, True)) -def test_branch_subscopes_fission(with_raising = False): +def test_branch_subscopes_fission(with_raising): sdfg = dace.SDFG('branch_subscope_fission') sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [2], dace.int32) diff --git a/tests/passes/scalar_write_shadow_scopes_analysis_test.py b/tests/passes/scalar_write_shadow_scopes_analysis_test.py index f0648cc2ba..78704bca60 100644 --- a/tests/passes/scalar_write_shadow_scopes_analysis_test.py +++ b/tests/passes/scalar_write_shadow_scopes_analysis_test.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize('with_raising', (False, True)) -def test_scalar_write_shadow_split(with_raising = False): +def test_scalar_write_shadow_split(with_raising): """ Test the scalar write shadow scopes pass with writes dominating reads across state. """ @@ -112,7 +112,7 @@ def test_scalar_write_shadow_split(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_scalar_write_shadow_fused(with_raising = False): +def test_scalar_write_shadow_fused(with_raising): """ Test the scalar write shadow scopes pass with writes dominating reads in the same state. """ @@ -196,7 +196,7 @@ def test_scalar_write_shadow_fused(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_scalar_write_shadow_interstate_self(with_raising = False): +def test_scalar_write_shadow_interstate_self(with_raising): """ Tests the scalar write shadow pass with interstate edge reads being shadowed by the state they're originating from. """ @@ -300,7 +300,7 @@ def test_scalar_write_shadow_interstate_self(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_scalar_write_shadow_interstate_pred(with_raising = False): +def test_scalar_write_shadow_interstate_pred(with_raising): """ Tests the scalar write shadow pass with interstate edge reads being shadowed by a predecessor state. """ @@ -408,7 +408,7 @@ def test_scalar_write_shadow_interstate_pred(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_loop_fake_shadow(with_raising = False): +def test_loop_fake_shadow(with_raising): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -460,7 +460,7 @@ def test_loop_fake_shadow(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_loop_fake_complex_shadow(with_raising = False): +def test_loop_fake_complex_shadow(with_raising): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -504,7 +504,7 @@ def test_loop_fake_complex_shadow(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_loop_real_shadow(with_raising = False): +def test_loop_real_shadow(with_raising): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -551,7 +551,7 @@ def test_loop_real_shadow(with_raising = False): @pytest.mark.parametrize('with_raising', (False, True)) -def test_dominationless_write_branch(with_raising = False): +def test_dominationless_write_branch(with_raising): sdfg = dace.SDFG('dominationless_write_branch') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) From 05b3e59ba4cc90675b67be1d611eb493829c250d Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 16 Oct 2024 16:11:14 +0200 Subject: [PATCH 047/108] Fixes --- dace/codegen/targets/framecode.py | 4 +- dace/sdfg/state.py | 6 +- .../interstate/move_loop_into_map.py | 3 +- .../passes/constant_propagation.py | 2 +- .../move_loop_into_map_test.py | 418 +++++++++--------- 5 files changed, 224 insertions(+), 209 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 62dc828590..769310c655 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -848,8 +848,8 @@ def deallocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: desc = node.desc(tsdfg) - self._dispatcher.dispatch_deallocate(tsdfg, cfg, state, state_id, node, desc, function_stream, - callsite_stream) + self._dispatcher.dispatch_deallocate(tsdfg, state.parent_graph, state, state_id, node, desc, + function_stream, callsite_stream) def generate_code(self, sdfg: SDFG, diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a5e8efee5a..2ff947e02c 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3209,12 +3209,12 @@ def replace_dict(self, replace_in_graph: bool = True, replace_keys: bool = True): if replace_keys: - from dace.sdfg.replace import replace_properties_dict - replace_properties_dict(self, repl, symrepl) - if self.loop_variable and self.loop_variable in repl: self.loop_variable = repl[self.loop_variable] + from dace.sdfg.replace import replace_properties_dict + replace_properties_dict(self, repl, symrepl) + super().replace_dict(repl, symrepl, replace_in_graph) def add_break(self, label=None) -> BreakBlock: diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index 5017ba3bea..3ac1f5a9e9 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -231,7 +231,8 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): from dace.transformation.interstate import RefineNestedAccess transformation = RefineNestedAccess() - transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0) + transformation.setup_match(sdfg, body.parent_graph.cfg_id, body.block_id, + {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0) transformation.apply(body, sdfg) # Second propagation for refined accesses. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index f6311dea6f..4e8da6ee42 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -126,7 +126,7 @@ def _add_nested_datanames(name: str, desc: data.Structure): if isinstance(block, SDFGState): # Replace in state contents block.replace_dict(mapping) - elif isinstance(block, AbstractControlFlowRegion): + elif isinstance(block, ControlFlowRegion): block.replace_dict(mapping, replace_in_graph=False, replace_keys=False) # Replace in outgoing edges as well diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index dca775bb7a..fd2d235dbc 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -1,7 +1,8 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + import dace +from dace.sdfg.state import LoopRegion from dace.transformation.interstate import MoveLoopIntoMap -import unittest import numpy as np I = dace.symbol("I") @@ -69,206 +70,219 @@ def apply_multiple_times_1(A: dace.float64[10, 10, 10, 10]): A[k, i, j, l] = k * 1000 + i * 100 + j * 10 + l -class MoveLoopIntoMapTest(unittest.TestCase): - - def semantic_eq(self, program): - A1 = np.random.rand(16, 16) - A2 = np.copy(A1) - - sdfg = program.to_sdfg(simplify=True) - sdfg(A1, I=A1.shape[0], J=A1.shape[1]) - - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertGreater(count, 0) - sdfg(A2, I=A2.shape[0], J=A2.shape[1]) - - self.assertTrue(np.allclose(A1, A2)) - - def test_forward_loops_semantic_eq(self): - self.semantic_eq(forward_loop) - - def test_backward_loops_semantic_eq(self): - self.semantic_eq(backward_loop) - - def test_multiple_edges(self): - self.semantic_eq(multiple_edges) - - def test_itervar_in_map_range(self): - sdfg = should_not_apply_1.to_sdfg(simplify=True) - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEqual(count, 0) - - def test_itervar_in_data(self): - sdfg = should_not_apply_2.to_sdfg(simplify=True) - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEqual(count, 0) - - def test_non_injective_index(self): - sdfg = should_not_apply_3.to_sdfg(simplify=True) - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEqual(count, 0) - - def test_apply_multiple_times(self): - sdfg = apply_multiple_times.to_sdfg(simplify=True) - overall = 0 - count = 1 - while (count > 0): - count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) - overall += count - sdfg.simplify() - - self.assertEqual(overall, 2) - - val = np.zeros((10, 10, 10), dtype=np.float64) - ref = val.copy() - - sdfg(A=val) - apply_multiple_times.f(ref) - - self.assertTrue(np.allclose(val, ref)) - - def test_apply_multiple_times_1(self): - sdfg = apply_multiple_times_1.to_sdfg(simplify=True) - overall = 0 - count = 1 - while (count > 0): - count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) - overall += count - sdfg.simplify() - - self.assertEqual(overall, 2) - - val = np.zeros((10, 10, 10, 10), dtype=np.float64) - ref = val.copy() - - sdfg(A=val) - apply_multiple_times_1.f(ref) - - self.assertTrue(np.allclose(val, ref)) - - def test_more_than_a_map(self): - """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """ - sdfg = dace.SDFG('more_than_a_map') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - oread = body.add_access('out') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(out=oread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) - - def test_more_than_a_map_1(self): - """ - `out` is written indirectly by the MapExit but is not read and, therefore, does not create a RW dependency. - """ - sdfg = dace.SDFG('more_than_a_map_1') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(A=aread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertTrue(count > 0) - - A = np.arange(9, dtype=np.float64).reshape(3, 3).copy() - B = np.arange(9, 18, dtype=np.float64).reshape(3, 3).copy() - val = np.empty((3, 3), dtype=np.float64) - sdfg(A=A, B=B, out=val) - - def reference(A, B): - for i in range(10): - tmp = A - B - out = tmp - return out - - ref = reference(A, B) - self.assertTrue(np.allclose(val, ref)) - - def test_more_than_a_map_2(self): - """ `out` is written indirectly by the MapExit with a subset dependent on the loop variable. This creates a RW - dependency. - """ - sdfg = dace.SDFG('more_than_a_map_2') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(A=aread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(twrite, owrite, dace.Memlet('out[k%3, (k+1)%3]', other_subset='(k+1)%3, k%3')) - sdfg.add_loop(None, body, None, 'k', '0', 'k < 10', 'k + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) - - def test_more_than_a_map_3(self): - """ There are more than one connected components in the loop body. The transformation should not apply. """ - sdfg = dace.SDFG('more_than_a_map_3') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(A=aread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - aread2 = body.add_access('A') - owrite2 = body.add_access('out') - body.add_nedge(aread2, owrite2, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) +def _semantic_eq(program): + A1 = np.random.rand(16, 16) + A2 = np.copy(A1) + + sdfg = program.to_sdfg(simplify=True) + sdfg(A1, I=A1.shape[0], J=A1.shape[1]) + + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count > 0 + sdfg(A2, I=A2.shape[0], J=A2.shape[1]) + + assert np.allclose(A1, A2) + +def test_forward_loops_semantic_eq(): + _semantic_eq(forward_loop) + +def test_backward_loops_semantic_eq(): + _semantic_eq(backward_loop) + +def test_multiple_edges(): + _semantic_eq(multiple_edges) + +def test_itervar_in_map_range(): + sdfg = should_not_apply_1.to_sdfg(simplify=True) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_itervar_in_data(): + sdfg = should_not_apply_2.to_sdfg(simplify=True) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_non_injective_index(): + sdfg = should_not_apply_3.to_sdfg(simplify=True) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_apply_multiple_times(): + sdfg = apply_multiple_times.to_sdfg(simplify=True) + overall = 0 + count = 1 + while (count > 0): + count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) + overall += count + sdfg.simplify() + + assert overall == 2 + + val = np.zeros((10, 10, 10), dtype=np.float64) + ref = val.copy() + + sdfg(A=val) + apply_multiple_times.f(ref) + + assert np.allclose(val, ref) + +def test_apply_multiple_times_1(): + sdfg = apply_multiple_times_1.to_sdfg(simplify=True) + overall = 0 + count = 1 + while (count > 0): + count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) + overall += count + sdfg.simplify() + + assert overall == 2 + + val = np.zeros((10, 10, 10, 10), dtype=np.float64) + ref = val.copy() + + sdfg(A=val) + apply_multiple_times_1.f(ref) + + assert np.allclose(val, ref) + +def test_more_than_a_map(): + """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """ + sdfg = dace.SDFG('more_than_a_map') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', '_ < 10', '_', '_ = 0', '_ = _ + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + oread = body.add_access('out') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(out=oread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_more_than_a_map_1(): + """ + `out` is written indirectly by the MapExit but is not read and, therefore, does not create a RW dependency. + """ + sdfg = dace.SDFG('more_than_a_map_1') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', '_ < 10', '_', '_ = 0', '_ = _ + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(A=aread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count > 0 + + A = np.arange(9, dtype=np.float64).reshape(3, 3).copy() + B = np.arange(9, 18, dtype=np.float64).reshape(3, 3).copy() + val = np.empty((3, 3), dtype=np.float64) + sdfg(A=A, B=B, out=val) + + def reference(A, B): + for i in range(10): + tmp = A - B + out = tmp + return out + + ref = reference(A, B) + assert np.allclose(val, ref) + +def test_more_than_a_map_2(): + """ `out` is written indirectly by the MapExit with a subset dependent on the loop variable. This creates a RW + dependency. + """ + sdfg = dace.SDFG('more_than_a_map_2') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', 'k < 10', 'k', 'k = 0', 'k = k + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(A=aread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(twrite, owrite, dace.Memlet('out[k%3, (k+1)%3]', other_subset='(k+1)%3, k%3')) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_more_than_a_map_3(): + """ There are more than one connected components in the loop body. The transformation should not apply. """ + sdfg = dace.SDFG('more_than_a_map_3') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', '_ < 10', '_', '_ = 0', '_ = _ + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(A=aread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + aread2 = body.add_access('A') + owrite2 = body.add_access('out') + body.add_nedge(aread2, owrite2, dace.Memlet.from_array('out', oarr)) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 if __name__ == '__main__': - unittest.main() + test_forward_loops_semantic_eq() + test_backward_loops_semantic_eq() + test_multiple_edges() + test_itervar_in_map_range() + test_itervar_in_data() + test_non_injective_index() + test_apply_multiple_times() + test_apply_multiple_times_1() + test_more_than_a_map() + test_more_than_a_map_1() + test_more_than_a_map_2() + test_more_than_a_map_3() From f5b617c6f76af8c0f5f013fa5de9422265351b60 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 16 Oct 2024 16:18:39 +0200 Subject: [PATCH 048/108] Fix invalid graph manipulation in test --- tests/constant_array_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/constant_array_test.py b/tests/constant_array_test.py index 69444768af..95e92b169f 100644 --- a/tests/constant_array_test.py +++ b/tests/constant_array_test.py @@ -112,12 +112,10 @@ def test(a: dace.float64[10]): sdfg = test.to_sdfg(simplify=False) sdfg.apply_transformations_repeated([StateFusion, RedundantArray, RedundantSecondArray]) - state = sdfg.node(0) # modify cst to be a dace constant: the python frontend adds an assignment tasklet - n = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data == 'cst'][0] - for pred in state.predecessors(n): - state.remove_node(pred) + assign_state = sdfg.node(0) + sdfg.remove_node(assign_state) sdfg.add_constant('cst', 1.0, sdfg.arrays['cst']) From 33c928735e0d2c488ad5f5316ed6c3bdc3df8401 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 16 Oct 2024 17:15:38 +0200 Subject: [PATCH 049/108] Fixes --- .../codegen/instrumentation/data/data_dump.py | 66 +++++++------- dace/codegen/instrumentation/gpu_events.py | 80 ++++++++-------- dace/codegen/instrumentation/likwid.py | 44 ++++----- dace/codegen/instrumentation/papi.py | 91 ++++++++++--------- dace/codegen/instrumentation/provider.py | 39 +++++--- dace/codegen/instrumentation/timer.py | 42 ++++----- dace/codegen/targets/cpu.py | 4 +- dace/codegen/targets/cuda.py | 14 +-- dace/codegen/targets/framecode.py | 7 +- dace/codegen/targets/snitch.py | 4 +- dace/sdfg/utils.py | 7 +- .../passes/constant_propagation.py | 2 +- .../simplification/control_flow_raising.py | 1 + tests/codegen/allocation_lifetime_test.py | 2 +- tests/codegen/control_flow_detection_test.py | 3 +- tests/fortran/fortran_language_test.py | 16 +--- 16 files changed, 210 insertions(+), 212 deletions(-) diff --git a/dace/codegen/instrumentation/data/data_dump.py b/dace/codegen/instrumentation/data/data_dump.py index 5fc487f94d..e8c6236a01 100644 --- a/dace/codegen/instrumentation/data/data_dump.py +++ b/dace/codegen/instrumentation/data/data_dump.py @@ -1,10 +1,10 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -from dace import config, data as dt, dtypes, registry, SDFG +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from dace import data as dt, dtypes, registry, SDFG from dace.sdfg import nodes, is_devicelevel_gpu from dace.codegen.prettycode import CodeIOStream from dace.codegen.instrumentation.provider import InstrumentationProvider from dace.sdfg.scope import is_devicelevel_fpga -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.codegen import common from dace.codegen import cppunparse from dace.codegen.targets import cpp @@ -101,7 +101,8 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod if sdfg.parent is None: sdfg.append_exit_code('delete __state->serializer;\n') - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream): + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream): if state.symbol_instrument == dtypes.DataInstrumentationType.No_Instrumentation: return @@ -119,17 +120,17 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea condition_preamble = f'if ({cond_string})' + ' {' condition_postamble = '}' - state_id = sdfg.node_id(state) - local_stream.write(condition_preamble, sdfg, state_id) + state_id = cfg.node_id(state) + local_stream.write(condition_preamble, cfg, state_id) defined_symbols = state.defined_symbols() for sym, _ in defined_symbols.items(): local_stream.write( - f'__state->serializer->save_symbol("{sym}", "{state_id}", {cpp.sym2cpp(sym)});\n', sdfg, state_id + f'__state->serializer->save_symbol("{sym}", "{state_id}", {cpp.sym2cpp(sym)});\n', cfg, state_id ) - local_stream.write(condition_postamble, sdfg, state_id) + local_stream.write(condition_postamble, cfg, state_id) - def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream): + def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.AccessNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream): from dace.codegen.dispatcher import DefinedType # Avoid import loop if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node): @@ -159,9 +160,9 @@ def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, oute ptrname = '&' + ptrname # Create UUID - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(node) - uuid = f'{sdfg.cfg_id}_{state_id}_{node_id}' + uuid = f'{cfg.cfg_id}_{state_id}_{node_id}' # Get optional pre/postamble for instrumenting device data preamble, postamble = '', '' @@ -174,13 +175,13 @@ def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, oute strides = ', '.join(cpp.sym2cpp(s) for s in desc.strides) # Write code - inner_stream.write(condition_preamble, sdfg, state_id, node_id) - inner_stream.write(preamble, sdfg, state_id, node_id) + inner_stream.write(condition_preamble, cfg, state_id, node_id) + inner_stream.write(preamble, cfg, state_id, node_id) inner_stream.write( f'__state->serializer->save({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, ' - f'"{node.data}", "{uuid}", {shape}, {strides});\n', sdfg, state_id, node_id) - inner_stream.write(postamble, sdfg, state_id, node_id) - inner_stream.write(condition_postamble, sdfg, state_id, node_id) + f'"{node.data}", "{uuid}", {shape}, {strides});\n', cfg, state_id, node_id) + inner_stream.write(postamble, cfg, state_id, node_id) + inner_stream.write(condition_postamble, cfg, state_id, node_id) @registry.autoregister_params(type=dtypes.DataInstrumentationType.Restore) @@ -216,7 +217,8 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod if sdfg.parent is None: sdfg.append_exit_code('delete __state->serializer;\n') - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream): + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream): if state.symbol_instrument == dtypes.DataInstrumentationType.No_Instrumentation: return @@ -234,18 +236,18 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea condition_preamble = f'if ({cond_string})' + ' {' condition_postamble = '}' - state_id = sdfg.node_id(state) - local_stream.write(condition_preamble, sdfg, state_id) + state_id = state.block_id + local_stream.write(condition_preamble, cfg, state_id) defined_symbols = state.defined_symbols() for sym, sym_type in defined_symbols.items(): local_stream.write( f'{cpp.sym2cpp(sym)} = __state->serializer->restore_symbol<{sym_type.ctype}>("{sym}", "{state_id}");\n', - sdfg, state_id + cfg, state_id ) - local_stream.write(condition_postamble, sdfg, state_id) + local_stream.write(condition_postamble, cfg, state_id) - def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream): + def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.AccessNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream): from dace.codegen.dispatcher import DefinedType # Avoid import loop if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node): @@ -275,21 +277,21 @@ def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, ou ptrname = '&' + ptrname # Create UUID - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(node) - uuid = f'{sdfg.cfg_id}_{state_id}_{node_id}' + uuid = f'{cfg.cfg_id}_{state_id}_{node_id}' # Get optional pre/postamble for instrumenting device data preamble, postamble = '', '' if desc.storage == dtypes.StorageType.GPU_Global: - self._setup_gpu_runtime(sdfg, global_stream) + self._setup_gpu_runtime(cfg, global_stream) preamble, postamble, ptrname = self._generate_copy_to_device(node, desc, ptrname) # Write code - inner_stream.write(condition_preamble, sdfg, state_id, node_id) - inner_stream.write(preamble, sdfg, state_id, node_id) + inner_stream.write(condition_preamble, cfg, state_id, node_id) + inner_stream.write(preamble, cfg, state_id, node_id) inner_stream.write( f'__state->serializer->restore({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, ' - f'"{node.data}", "{uuid}");\n', sdfg, state_id, node_id) - inner_stream.write(postamble, sdfg, state_id, node_id) - inner_stream.write(condition_postamble, sdfg, state_id, node_id) + f'"{node.data}", "{uuid}");\n', cfg, state_id, node_id) + inner_stream.write(postamble, cfg, state_id, node_id) + inner_stream.write(condition_postamble, cfg, state_id, node_id) diff --git a/dace/codegen/instrumentation/gpu_events.py b/dace/codegen/instrumentation/gpu_events.py index cfd5a1cbb3..3d367d444e 100644 --- a/dace/codegen/instrumentation/gpu_events.py +++ b/dace/codegen/instrumentation/gpu_events.py @@ -6,7 +6,7 @@ from dace.codegen import common from dace.codegen.instrumentation.provider import InstrumentationProvider from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState @registry.autoregister_params(type=dtypes.InstrumentationType.GPU_Events) @@ -53,8 +53,8 @@ def _record_event(self, id, stream): streamstr = f'__state->gpu_context->streams[{stream}]' return '%sEventRecord(__dace_ev_%s, %s);' % (self.backend, id, streamstr) - def _report(self, timer_name: str, sdfg: SDFG = None, state: SDFGState = None, node: nodes.Node = None): - idstr = self._idstr(sdfg, state, node) + def _report(self, timer_name: str, cfg: ControlFlowRegion = None, state: SDFGState = None, node: nodes.Node = None): + idstr = self._idstr(cfg, state, node) state_id = -1 node_id = -1 @@ -73,12 +73,12 @@ def _report(self, timer_name: str, sdfg: SDFG = None, state: SDFGState = None, n id=idstr, timer_name=timer_name, backend=self.backend, - cfg_id=sdfg.cfg_id, + cfg_id=cfg.cfg_id, state_id=state_id, node_id=node_id) # Code generation hooks - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) # Create GPU events for each instrumented scope in the state @@ -86,84 +86,84 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea if isinstance(node, (nodes.CodeNode, nodes.EntryNode)): s = (self._get_sobj(node) if isinstance(node, nodes.EntryNode) else node) if s.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = self._idstr(sdfg, state, node) - local_stream.write(self._create_event('b' + idstr), sdfg, state_id, node) - local_stream.write(self._create_event('e' + idstr), sdfg, state_id, node) + idstr = self._idstr(cfg, state, node) + local_stream.write(self._create_event('b' + idstr), cfg, state_id, node) + local_stream.write(self._create_event('e' + idstr), cfg, state_id, node) # Create and record a CUDA/HIP event for the entire state if state.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = 'b' + self._idstr(sdfg, state, None) - local_stream.write(self._create_event(idstr), sdfg, state_id) - local_stream.write(self._record_event(idstr, 0), sdfg, state_id) - idstr = 'e' + self._idstr(sdfg, state, None) - local_stream.write(self._create_event(idstr), sdfg, state_id) + idstr = 'b' + self._idstr(cfg, state, None) + local_stream.write(self._create_event(idstr), cfg, state_id) + local_stream.write(self._record_event(idstr, 0), cfg, state_id) + idstr = 'e' + self._idstr(cfg, state, None) + local_stream.write(self._create_event(idstr), cfg, state_id) - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) # Record and measure state stream event if state.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = self._idstr(sdfg, state, None) - local_stream.write(self._record_event('e' + idstr, 0), sdfg, state_id) - local_stream.write(self._report('State %s' % state.label, sdfg, state), sdfg, state_id) - local_stream.write(self._destroy_event('b' + idstr), sdfg, state_id) - local_stream.write(self._destroy_event('e' + idstr), sdfg, state_id) + idstr = self._idstr(cfg, state, None) + local_stream.write(self._record_event('e' + idstr, 0), cfg, state_id) + local_stream.write(self._report('State %s' % state.label, sdfg, state), cfg, state_id) + local_stream.write(self._destroy_event('b' + idstr), cfg, state_id) + local_stream.write(self._destroy_event('e' + idstr), cfg, state_id) # Destroy CUDA/HIP events for scopes in the state for node in state.nodes(): if isinstance(node, (nodes.CodeNode, nodes.EntryNode)): s = (self._get_sobj(node) if isinstance(node, nodes.EntryNode) else node) if s.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = self._idstr(sdfg, state, node) - local_stream.write(self._destroy_event('b' + idstr), sdfg, state_id, node) - local_stream.write(self._destroy_event('e' + idstr), sdfg, state_id, node) + idstr = self._idstr(cfg, state, node) + local_stream.write(self._destroy_event('b' + idstr), cfg, state_id, node) + local_stream.write(self._destroy_event('e' + idstr), cfg, state_id, node) - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.GPU_Events: if s.schedule != dtypes.ScheduleType.GPU_Device: raise TypeError('GPU Event instrumentation only applies to ' 'GPU_Device map scopes') - idstr = 'b' + self._idstr(sdfg, state, node) + idstr = 'b' + self._idstr(cfg, state, node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) entry_node = state.entry_node(node) s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = 'e' + self._idstr(sdfg, state, entry_node) + idstr = 'e' + self._idstr(cfg, state, entry_node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) - outer_stream.write(self._report('%s %s' % (type(s).__name__, s.label), sdfg, state, entry_node), sdfg, + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) + outer_stream.write(self._report('%s %s' % (type(s).__name__, s.label), cfg, state, entry_node), cfg, state_id, node) - def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if (not isinstance(node, nodes.CodeNode) or is_devicelevel_gpu(sdfg, state, node)): return # Only run for host nodes # TODO(later): Implement "clock64"-based GPU counters if node.instrument == dtypes.InstrumentationType.GPU_Events: state_id = state.parent_graph.node_id(state) - idstr = 'b' + self._idstr(sdfg, state, node) + idstr = 'b' + self._idstr(cfg, state, node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) - def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if (not isinstance(node, nodes.Tasklet) or is_devicelevel_gpu(sdfg, state, node)): return # Only run for host nodes # TODO(later): Implement "clock64"-based GPU counters if node.instrument == dtypes.InstrumentationType.GPU_Events: state_id = state.parent_graph.node_id(state) - idstr = 'e' + self._idstr(sdfg, state, node) + idstr = 'e' + self._idstr(cfg, state, node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) - outer_stream.write(self._report('%s %s' % (type(node).__name__, node.label), sdfg, state, node), sdfg, + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) + outer_stream.write(self._report('%s %s' % (type(node).__name__, node.label), cfg, state, node), cfg, state_id, node) diff --git a/dace/codegen/instrumentation/likwid.py b/dace/codegen/instrumentation/likwid.py index 8d1c9e3b71..bd9ffe63a7 100644 --- a/dace/codegen/instrumentation/likwid.py +++ b/dace/codegen/instrumentation/likwid.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Implements the LIKWID counter performance instrumentation provider. Used for collecting CPU performance counters. """ @@ -15,7 +15,7 @@ from dace.config import Config from dace.sdfg import nodes from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.transformation import helpers as xfh @@ -213,13 +213,13 @@ def on_sdfg_end(self, sdfg, local_stream, global_stream): ''' self.codegen._exitcode.write(exit_code, sdfg) - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_CPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -250,13 +250,13 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea ''' local_stream.write(marker_code) - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_CPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -269,8 +269,8 @@ def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, ''' local_stream.write(marker_code) - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used or node.instrument != dace.InstrumentationType.LIKWID_CPU: return @@ -279,7 +279,7 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou elif node.schedule not in LIKWIDInstrumentationCPU.perf_whitelist_schedules: raise TypeError("Unsupported schedule on scope") - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(node) region = f"scope_{cfg_id}_{state_id}_{node_id}" @@ -296,13 +296,13 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou ''' outer_stream.write(marker_code) - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: entry_node = state.entry_node(node) if not self._likwid_used or entry_node.instrument != dace.InstrumentationType.LIKWID_CPU: return - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(entry_node) region = f"scope_{cfg_id}_{state_id}_{node_id}" @@ -405,13 +405,13 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod ''' self.codegen._exitcode.write(exit_code, sdfg) - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_GPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -428,13 +428,13 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea ''' local_stream.write(marker_code) - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_GPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -444,8 +444,8 @@ def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, ''' local_stream.write(marker_code) - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used or node.instrument != dace.InstrumentationType.LIKWID_GPU: return @@ -454,7 +454,7 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou elif node.schedule not in LIKWIDInstrumentationGPU.perf_whitelist_schedules: raise TypeError("Unsupported schedule on scope") - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(node) region = f"scope_{cfg_id}_{state_id}_{node_id}" @@ -471,13 +471,13 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou ''' outer_stream.write(marker_code) - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: entry_node = state.entry_node(node) if not self._likwid_used or entry_node.instrument != dace.InstrumentationType.LIKWID_GPU: return - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(entry_node) region = f"scope_{cfg_id}_{state_id}_{node_id}" diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index 4885611408..ac2f6aafb7 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -103,19 +103,19 @@ def on_sdfg_end(self, sdfg, local_stream, global_stream): local_stream.write('__perf_store.flush();', sdfg) - def on_state_begin(self, sdfg, state, local_stream, global_stream): + def on_state_begin(self, sdfg, cfg, state, local_stream, global_stream): if not self._papi_used: return if state.instrument == dace.InstrumentationType.PAPI_Counters: - uid = _unified_id(-1, sdfg.node_id(state)) + uid = _unified_id(-1, cfg.node_id(state)) local_stream.write("__perf_store.markSuperSectionStart(%d);" % uid) - def on_copy_begin(self, sdfg, state, src_node, dst_node, edge, local_stream, global_stream, copy_shape, src_strides, - dst_strides): + def on_copy_begin(self, sdfg, cfg, state, src_node, dst_node, edge, local_stream, global_stream, copy_shape, + src_strides, dst_strides): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) memlet = edge.data # For perfcounters, we have to make sure that: @@ -153,7 +153,7 @@ def on_copy_begin(self, sdfg, state, src_node, dst_node, edge, local_stream, glo # would be a section with 1 entry)) local_stream.write( self.perf_section_start_string(node_id, copy_size, copy_size), - sdfg, + cfg, state_id, [src_node, dst_node], ) @@ -169,34 +169,34 @@ def on_copy_begin(self, sdfg, state, src_node, dst_node, edge, local_stream, glo unique_id=unique_cpy_id, size=copy_size, ), - sdfg, + cfg, state_id, [src_node, dst_node], ) - def on_copy_end(self, sdfg, state, src_node, dst_node, edge, local_stream, global_stream): + def on_copy_end(self, sdfg, cfg, state, src_node, dst_node, edge, local_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(dst_node) if self.perf_should_instrument: unique_cpy_id = self._unique_counter local_stream.write( "__perf_cpy_%d_%d.leaveCritical(__vs_cpy_%d_%d);" % (node_id, unique_cpy_id, node_id, unique_cpy_id), - sdfg, + cfg, state_id, [src_node, dst_node], ) self.perf_should_instrument = False - def on_node_begin(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_begin(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) unified_id = _unified_id(state.node_id(node), state_id) perf_should_instrument = (node.instrument == dace.InstrumentationType.PAPI_Counters @@ -207,25 +207,25 @@ def on_node_begin(self, sdfg, state, node, outer_stream, inner_stream, global_st if isinstance(node, nodes.Tasklet): inner_stream.write( "dace::perf::%s __perf_%s;\n" % (self.perf_counter_string(), node.label), - sdfg, + cfg, state_id, node, ) inner_stream.write( 'auto& __perf_vs_%s = __perf_store.getNewValueSet(__perf_%s, ' ' %d, PAPI_thread_id(), 0);\n' % (node.label, node.label, unified_id), - sdfg, + cfg, state_id, node, ) - inner_stream.write("__perf_%s.enterCritical();\n" % node.label, sdfg, state_id, node) + inner_stream.write("__perf_%s.enterCritical();\n" % node.label, cfg, state_id, node) - def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_end(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(node) unified_id = _unified_id(node_id, state_id) @@ -234,7 +234,7 @@ def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stre if not PAPIInstrumentation.has_surrounding_perfcounters(node, state): inner_stream.write( "__perf_%s.leaveCritical(__perf_vs_%s);" % (node.label, node.label), - sdfg, + cfg, state_id, node, ) @@ -242,21 +242,21 @@ def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stre # Add bytes moved inner_stream.write( "__perf_store.addBytesMoved(%s);" % - PAPIUtils.get_tasklet_byte_accesses(node, state, sdfg, state_id), sdfg, state_id, node) + PAPIUtils.get_tasklet_byte_accesses(node, state, sdfg, cfg, state_id), cfg, state_id, node) - def on_scope_entry(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return if isinstance(node, nodes.MapEntry): - return self.on_map_entry(sdfg, state, node, outer_stream, inner_stream) + return self.on_map_entry(sdfg, cfg, state, node, outer_stream, inner_stream) elif isinstance(node, nodes.ConsumeEntry): - return self.on_consume_entry(sdfg, state, node, outer_stream, inner_stream) + return self.on_consume_entry(sdfg, cfg, state, node, outer_stream, inner_stream) raise TypeError - def on_scope_exit(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_exit(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) entry_node = state.entry_node(node) if not self.should_instrument_entry(entry_node): return @@ -265,11 +265,11 @@ def on_scope_exit(self, sdfg, state, node, outer_stream, inner_stream, global_st perf_end_string = self.perf_counter_end_measurement_string(unified_id) # Inner part - inner_stream.write(perf_end_string, sdfg, state_id, node) + inner_stream.write(perf_end_string, cfg, state_id, node) - def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): + def on_map_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream): dfg = state.scope_subgraph(node) - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) if node.map.instrument != dace.InstrumentationType.PAPI_Counters: return @@ -280,7 +280,7 @@ def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): result = outer_stream - input_size: str = PAPIUtils.get_memory_input_size(node, sdfg, state_id) + input_size: str = PAPIUtils.get_memory_input_size(node, sdfg, cfg, state_id) # Emit supersection if possible result.write(self.perf_get_supersection_start_string(node, dfg, unified_id)) @@ -288,7 +288,7 @@ def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): if not self.should_instrument_entry(node): return - size = PAPIUtils.accumulate_byte_movement(node, node, dfg, sdfg, state_id) + size = PAPIUtils.accumulate_byte_movement(node, node, dfg, sdfg, cfg, state_id) size = sym2cpp(sp.simplify(size)) result.write(self.perf_section_start_string(unified_id, size, input_size)) @@ -299,10 +299,10 @@ def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): map_name = node.map.params[-1] - result.write(self.perf_counter_start_measurement_string(unified_id, map_name), sdfg, state_id, node) + result.write(self.perf_counter_start_measurement_string(unified_id, map_name), cfg, state_id, node) - def on_consume_entry(self, sdfg, state, node, outer_stream, inner_stream): - state_id = sdfg.node_id(state) + def on_consume_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream): + state_id = cfg.node_id(state) unified_id = _unified_id(state.node_id(node), state_id) # Outer part @@ -312,18 +312,18 @@ def on_consume_entry(self, sdfg, state, node, outer_stream, inner_stream): # Mark the SuperSection start (if possible) result.write( self.perf_get_supersection_start_string(node, state, unified_id), - sdfg, + cfg, state_id, node, ) # Mark the section start with zeros (due to dynamic accesses) - result.write(self.perf_section_start_string(unified_id, "0", "0"), sdfg, state_id, node) + result.write(self.perf_section_start_string(unified_id, "0", "0"), cfg, state_id, node) # Generate a thread affinity locker result.write( "dace::perf::ThreadLockProvider __perf_tlp_%d;\n" % unified_id, - sdfg, + cfg, state_id, node, ) @@ -343,7 +343,7 @@ def on_consume_entry(self, sdfg, state, node, outer_stream, inner_stream): "__perf_tlp_{id}.getAndIncreaseCounter()".format(id=unified_id), core_str="dace::perf::getThreadID()", ), - sdfg, + cfg, state_id, node, ) @@ -605,8 +605,8 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): - scope_dict = sdfg.node(state_id).scope_dict() + def get_out_memlet_costs(sdfg: dace.SDFG, cfg, state_id: int, node: nodes.Node, dfg: DataflowGraphView): + scope_dict = cfg.node(state_id).scope_dict() out_costs = 0 for edge in dfg.out_edges(node): @@ -639,6 +639,7 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: DataflowGraphView, sdfg: dace.SDFG, + cfg, state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. The formula is sum(inedges * size) + sum(outedges * size) """ @@ -649,7 +650,7 @@ def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, for ie in in_edges: in_accum.append(PAPIUtils.get_memlet_byte_size(sdfg, ie.data)) - out_accum.append(PAPIUtils.get_out_memlet_costs(sdfg, state_id, tasklet, dfg)) + out_accum.append(PAPIUtils.get_out_memlet_costs(sdfg, cfg, state_id, tasklet, dfg)) # Merge full = in_accum @@ -663,7 +664,7 @@ def get_parents(outermost_node: nodes.Node, node: nodes.Node, sdfg: dace.SDFG, s parent = None # Because dfg is only a subgraph view, it does not contain the entry # node for a given entry. This O(n) solution is suboptimal - for state in sdfg.nodes(): + for state in sdfg.states(): s_d = state.scope_dict() try: scope = s_d[node] @@ -681,8 +682,8 @@ def get_parents(outermost_node: nodes.Node, node: nodes.Node, sdfg: dace.SDFG, s return PAPIUtils.get_parents(outermost_node, parent, sdfg, state_id) + [parent] @staticmethod - def get_memory_input_size(node, sdfg, state_id) -> str: - curr_state = sdfg.nodes()[state_id] + def get_memory_input_size(node, sdfg, cfg, state_id) -> str: + curr_state = cfg.node(state_id) input_size = 0 for edge in curr_state.in_edges(node): @@ -696,7 +697,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, cfg, state_id): itvars = dict() # initialize an empty dict @@ -711,7 +712,7 @@ def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, if len(children) > 0: size = 0 for x in children: - size = size + PAPIUtils.accumulate_byte_movement(outermost_node, x, dfg, sdfg, state_id) + size = size + PAPIUtils.accumulate_byte_movement(outermost_node, x, dfg, sdfg, cfg, state_id) return size else: @@ -740,7 +741,7 @@ def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, return 0 # We can ignore this. elif isinstance(node, Tasklet): return itcount * symbolic.pystr_to_symbolic( - PAPIUtils.get_tasklet_byte_accesses(node, dfg, sdfg, state_id)) + PAPIUtils.get_tasklet_byte_accesses(node, dfg, sdfg, cfg, state_id)) elif isinstance(node, nodes.AccessNode): return 0 else: diff --git a/dace/codegen/instrumentation/provider.py b/dace/codegen/instrumentation/provider.py index a3748b241b..9374ed60dd 100644 --- a/dace/codegen/instrumentation/provider.py +++ b/dace/codegen/instrumentation/provider.py @@ -60,34 +60,37 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod """ pass - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the beginning of SDFG state code generation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param local_stream: Code generator for the in-function code. :param global_stream: Code generator for global (external) code. """ pass - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the end of SDFG state code generation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param local_stream: Code generator for the in-function code. :param global_stream: Code generator for global (external) code. """ pass - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the beginning of a scope (on generating an EntryNode). :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The EntryNode object from which code is generated. :param outer_stream: Code generator for the internal code before @@ -98,11 +101,12 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou """ pass - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the end of a scope (on generating an ExitNode). :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The ExitNode object from which code is generated. :param outer_stream: Code generator for the internal code after @@ -113,12 +117,13 @@ def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, oute """ pass - def on_copy_begin(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_node: nodes.Node, - edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, global_stream: CodeIOStream, - copy_shape, src_strides, dst_strides) -> None: + def on_copy_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream, copy_shape, src_strides, dst_strides) -> None: """ Event called at the beginning of generating a copy operation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param src_node: The source node of the copy. :param dst_node: The destination node of the copy. @@ -131,11 +136,13 @@ def on_copy_begin(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_ """ pass - def on_copy_end(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_node: nodes.Node, - edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_copy_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: """ Event called at the end of generating a copy operation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param src_node: The source node of the copy. :param dst_node: The destination node of the copy. @@ -145,11 +152,12 @@ def on_copy_end(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_no """ pass - def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the beginning of generating a node. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The generated node. :param outer_stream: Code generator for the internal code before @@ -160,11 +168,12 @@ def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_st """ pass - def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the end of generating a node. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The generated node. :param outer_stream: Code generator for the internal code after diff --git a/dace/codegen/instrumentation/timer.py b/dace/codegen/instrumentation/timer.py index a13e50faca..fea2cf70ea 100644 --- a/dace/codegen/instrumentation/timer.py +++ b/dace/codegen/instrumentation/timer.py @@ -16,24 +16,24 @@ def on_sdfg_begin(self, sdfg, local_stream, global_stream, codegen): sdfg.append_global_code('\n#include ', None) if sdfg.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(local_stream, sdfg) + self.on_tbegin(local_stream, sdfg, sdfg) def on_sdfg_end(self, sdfg, local_stream, global_stream): if sdfg.instrument == dtypes.InstrumentationType.Timer: - self.on_tend('SDFG %s' % sdfg.name, local_stream, sdfg) + self.on_tend('SDFG %s' % sdfg.name, local_stream, sdfg, sdfg) - def on_tbegin(self, stream: CodeIOStream, sdfg=None, state=None, node=None): - idstr = self._idstr(sdfg, state, node) + def on_tbegin(self, stream: CodeIOStream, sdfg=None, cfg=None, state=None, node=None): + idstr = self._idstr(cfg, state, node) stream.write('auto __dace_tbegin_%s = std::chrono::high_resolution_clock::now();' % idstr) - def on_tend(self, timer_name: str, stream: CodeIOStream, sdfg=None, state=None, node=None): - idstr = self._idstr(sdfg, state, node) + def on_tend(self, timer_name: str, stream: CodeIOStream, sdfg=None, cfg=None, state=None, node=None): + idstr = self._idstr(cfg, state, node) state_id = -1 node_id = -1 if state is not None: - state_id = sdfg.node_id(state) + state_id = state.block_id if node is not None: node_id = state.node_id(node) @@ -41,16 +41,16 @@ def on_tend(self, timer_name: str, stream: CodeIOStream, sdfg=None, state=None, unsigned long int __dace_ts_start_{id} = std::chrono::duration_cast(__dace_tbegin_{id}.time_since_epoch()).count(); unsigned long int __dace_ts_end_{id} = std::chrono::duration_cast(__dace_tend_{id}.time_since_epoch()).count(); __state->report.add_completion("{timer_name}", "Timer", __dace_ts_start_{id}, __dace_ts_end_{id}, {cfg_id}, {state_id}, {node_id});''' - .format(timer_name=timer_name, id=idstr, cfg_id=sdfg.cfg_id, state_id=state_id, node_id=node_id)) + .format(timer_name=timer_name, id=idstr, cfg_id=cfg.cfg_id, state_id=state_id, node_id=node_id)) # Code generation hooks - def on_state_begin(self, sdfg, state, local_stream, global_stream): + def on_state_begin(self, sdfg, cfg, state, local_stream, global_stream): if state.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(local_stream, sdfg, state) + self.on_tbegin(local_stream, sdfg, cfg, state) - def on_state_end(self, sdfg, state, local_stream, global_stream): + def on_state_end(self, sdfg, cfg, state, local_stream, global_stream): if state.instrument == dtypes.InstrumentationType.Timer: - self.on_tend('State %s' % state.label, local_stream, sdfg, state) + self.on_tend('State %s' % state.label, local_stream, sdfg, cfg, state) def _get_sobj(self, node): # Get object behind scope @@ -59,26 +59,26 @@ def _get_sobj(self, node): else: return node.map - def on_scope_entry(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(outer_stream, sdfg, state, node) + self.on_tbegin(outer_stream, sdfg, cfg, state, node) - def on_scope_exit(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_exit(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): entry_node = state.entry_node(node) s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.Timer: - self.on_tend('%s %s' % (type(s).__name__, s.label), outer_stream, sdfg, state, entry_node) + self.on_tend('%s %s' % (type(s).__name__, s.label), outer_stream, sdfg, cfg, state, entry_node) - def on_node_begin(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_begin(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not isinstance(node, CodeNode): return if node.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(outer_stream, sdfg, state, node) + self.on_tbegin(outer_stream, sdfg, cfg, state, node) - def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_end(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not isinstance(node, CodeNode): return if node.instrument == dtypes.InstrumentationType.Timer: - idstr = self._idstr(sdfg, state, node) - self.on_tend('%s %s' % (type(node).__name__, idstr), outer_stream, sdfg, state, node) + idstr = self._idstr(cfg, state, node) + self.on_tend('%s %s' % (type(node).__name__, idstr), outer_stream, sdfg, cfg, state, node) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 51daaa432b..fad672ffc1 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -1852,7 +1852,7 @@ def _generate_MapEntry( # Instrumentation: Pre-scope instr = self._dispatcher.instrumentation[node.map.instrument] if instr is not None: - instr.on_scope_entry(sdfg, state_dfg, node, callsite_stream, inner_stream, function_stream) + instr.on_scope_entry(sdfg, cfg, state_dfg, node, callsite_stream, inner_stream, function_stream) # TODO: Refactor to generate_scope_preamble once a general code # generator (that CPU inherits from) is implemented @@ -1966,7 +1966,7 @@ def _generate_MapExit(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra # Instrumentation: Post-scope instr = self._dispatcher.instrumentation[node.map.instrument] if instr is not None and not is_devicelevel_gpu(sdfg, state_dfg, node): - instr.on_scope_exit(sdfg, state_dfg, node, outer_stream, callsite_stream, function_stream) + instr.on_scope_exit(sdfg, cfg, state_dfg, node, outer_stream, callsite_stream, function_stream) self.generate_scope_postamble(sdfg, dfg, state_id, function_stream, outer_stream, callsite_stream) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index f080f2cc62..23d48fe9ea 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1304,7 +1304,7 @@ def generate_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, # Invoke all instrumentation providers for instr in self._frame._dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_end(sdfg, state, callsite_stream, function_stream) + instr.on_state_end(sdfg, cfg, state, callsite_stream, function_stream) def generate_devicelevel_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: @@ -1434,8 +1434,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub create_grid_barrier = True self.create_grid_barrier = create_grid_barrier - kernel_name = '%s_%d_%d_%d' % (scope_entry.map.label, sdfg.cfg_id, sdfg.node_id(state), - state.node_id(scope_entry)) + kernel_name = '%s_%d_%d_%d' % (scope_entry.map.label, sdfg.cfg_id, state.block_id, state.node_id(scope_entry)) # Comprehend grid/block dimensions from scopes grid_dims, block_dims, tbmap, dtbmap, _ = self.get_kernel_dimensions(dfg_scope) @@ -1495,9 +1494,10 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # Instrumentation for kernel scope instr = self._dispatcher.instrumentation[scope_entry.map.instrument] if instr is not None: - instr.on_scope_entry(sdfg, state, scope_entry, callsite_stream, self.scope_entry_stream, self._globalcode) + instr.on_scope_entry(sdfg, cfg, state, scope_entry, callsite_stream, self.scope_entry_stream, + self._globalcode) outer_stream = CodeIOStream() - instr.on_scope_exit(sdfg, state, scope_exit, outer_stream, self.scope_exit_stream, self._globalcode) + instr.on_scope_exit(sdfg, cfg, state, scope_exit, outer_stream, self.scope_exit_stream, self._globalcode) # Redefine constant arguments and rename arguments to device counterparts # TODO: This (const behavior and code below) is all a hack. @@ -1586,7 +1586,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # Write kernel prototype self._localcode.write( '__global__ void %s %s(%s) {\n' % - (launch_bounds, kernel_name, ', '.join(kernel_args_typed + extra_kernel_args_typed)), sdfg, state_id, node) + (launch_bounds, kernel_name, ', '.join(kernel_args_typed + extra_kernel_args_typed)), cfg, state_id, node) # Write constant expressions in GPU code self._frame.generate_constants(sdfg, self._localcode) @@ -2034,7 +2034,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S expr = _topy(bidx[i]).replace('__DAPB%d' % i, block_expr) - kernel_stream.write(f'{tidtype.ctype} {varname} = {expr};', sdfg, state_id, node) + kernel_stream.write(f'{tidtype.ctype} {varname} = {expr};', cfg, state_id, node) self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, tidtype.ctype) # Delinearize beyond the third dimension diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 769310c655..54273ac6b0 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import collections import copy import re @@ -16,7 +16,6 @@ from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator from dace.codegen.tools.type_inference import infer_expr_type -from dace.frontend.python import astutils from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils @@ -423,7 +422,7 @@ def generate_state(self, # Invoke all instrumentation providers for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_begin(sdfg, state, callsite_stream, global_stream) + instr.on_state_begin(sdfg, cfg, state, callsite_stream, global_stream) ##################### # Create dataflow graph for state's children. @@ -470,7 +469,7 @@ def generate_state(self, # Invoke all instrumentation providers for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_end(sdfg, state, callsite_stream, global_stream) + instr.on_state_end(sdfg, cfg, state, callsite_stream, global_stream) def generate_states(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stream: CodeIOStream) -> Set[SDFGState]: states_generated = set() diff --git a/dace/codegen/targets/snitch.py b/dace/codegen/targets/snitch.py index 5a62ca2995..bcdcb61941 100644 --- a/dace/codegen/targets/snitch.py +++ b/dace/codegen/targets/snitch.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from typing import Union import dace @@ -201,7 +201,7 @@ def generate_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, g # Invoke all instrumentation providers for instr in self.dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_begin(sdfg, state, callsite_stream, global_stream) + instr.on_state_begin(sdfg, cfg, state, callsite_stream, global_stream) ##################### # Create dataflow graph for state's children. diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index a272057380..8a1fb68081 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1425,11 +1425,12 @@ def get_next_nonempty_states(sdfg: SDFG, state: SDFGState) -> Set[SDFGState]: result: Set[SDFGState] = set() # Traverse children until states are not empty - for succ in sdfg.successors(state): - result |= set(dfs_conditional(sdfg, sources=[succ], condition=lambda parent, _: parent.is_empty())) + for succ in state.parent_graph.successors(state): + result |= set(dfs_conditional(state.parent_graph, sources=[succ], + condition=lambda parent, _: parent.number_of_nodes() == 0)) # Filter out empty states - result = {s for s in result if not s.is_empty()} + result = {s for s in result if not s.number_of_nodes() == 0} return result diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 4e8da6ee42..f6311dea6f 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -126,7 +126,7 @@ def _add_nested_datanames(name: str, desc: data.Structure): if isinstance(block, SDFGState): # Replace in state contents block.replace_dict(mapping) - elif isinstance(block, ControlFlowRegion): + elif isinstance(block, AbstractControlFlowRegion): block.replace_dict(mapping, replace_in_graph=False, replace_keys=False) # Replace in outgoing edges as well diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index d6e1f4c460..12a2b2908a 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -61,6 +61,7 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: conditional.add_branch(oe.data.condition, branch) if oe.dst is merge_block: # Empty branch. + branch.add_state('noop') continue branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) diff --git a/tests/codegen/allocation_lifetime_test.py b/tests/codegen/allocation_lifetime_test.py index 9a68cd2140..e87e6bf109 100644 --- a/tests/codegen/allocation_lifetime_test.py +++ b/tests/codegen/allocation_lifetime_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests different allocation lifetimes. """ import pytest diff --git a/tests/codegen/control_flow_detection_test.py b/tests/codegen/control_flow_detection_test.py index e97f7db77b..feeb6ff4fc 100644 --- a/tests/codegen/control_flow_detection_test.py +++ b/tests/codegen/control_flow_detection_test.py @@ -1,5 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from math import exp +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import pytest import dace diff --git a/tests/fortran/fortran_language_test.py b/tests/fortran/fortran_language_test.py index 32ab23714b..60ffcb36c4 100644 --- a/tests/fortran/fortran_language_test.py +++ b/tests/fortran/fortran_language_test.py @@ -1,22 +1,8 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.common.readfortran import FortranStringReader -from fparser.common.readfortran import FortranFileReader -from fparser.two.parser import ParserFactory -import sys, os import numpy as np -import pytest - -from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic from dace.frontend.fortran import fortran_parser -from fparser.two.symbol_table import SymbolTable -from dace.sdfg import utils as sdutil - -import dace.frontend.fortran.ast_components as ast_components -import dace.frontend.fortran.ast_transforms as ast_transforms -import dace.frontend.fortran.ast_utils as ast_utils -import dace.frontend.fortran.ast_internal_classes as ast_internal_classes def test_fortran_frontend_real_kind_selector(): From 9185897a23d99d562b1a313b3f3f57c57978308a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 16 Oct 2024 17:33:07 +0200 Subject: [PATCH 050/108] Adapt composite fusion --- dace/transformation/subgraph/composite.py | 7 +++--- dace/transformation/subgraph/expansion.py | 5 ++-- .../subgraph/subgraph_fusion.py | 23 ++++++++----------- tests/codegen/control_flow_detection_test.py | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/dace/transformation/subgraph/composite.py b/dace/transformation/subgraph/composite.py index e25ccd192a..886698f791 100644 --- a/dace/transformation/subgraph/composite.py +++ b/dace/transformation/subgraph/composite.py @@ -3,6 +3,7 @@ Subgraph Fusion - Stencil Tiling Transformation """ +from dace.sdfg.state import SDFGState, StateSubgraphView from dace.transformation.subgraph import SubgraphFusion, MultiExpansion from dace.transformation.subgraph.stencil_tiling import StencilTiling from dace.transformation.subgraph import helpers @@ -18,7 +19,7 @@ @make_properties -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class CompositeFusion(transformation.SubgraphTransformation): """ MultiExpansion + SubgraphFusion in one Transformation Additional StencilTiling is also possible as a canonicalizing @@ -46,8 +47,8 @@ class CompositeFusion(transformation.SubgraphTransformation): expansion_split = Property(desc="Allow MultiExpansion to split up maps, if enabled", dtype=bool, default=True) - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: - graph = subgraph.graph + def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: + graph: SDFGState = subgraph.graph if self.allow_expansion == True: subgraph_fusion = SubgraphFusion() subgraph_fusion.setup_match(subgraph) diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index aa182e8c80..b013627d2e 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -6,6 +6,7 @@ from dace.sdfg import nodes from dace.sdfg import replace, SDFG, dynamic_map_inputs from dace.sdfg.graph import SubgraphView +from dace.sdfg.state import SDFGState, StateSubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property from dace.transformation.subgraph import helpers @@ -58,12 +59,12 @@ class MultiExpansion(transformation.SubgraphTransformation): allow_offset = Property(dtype=bool, desc="Offset ranges to zero", default=True) - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: + def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: # get lowest scope maps of subgraph # grab first node and see whether all nodes are in the same graph # (or nested sdfgs therein) - graph = subgraph.graph + graph: SDFGState = subgraph.graph # next, get all the maps by obtaining a copy (for potential offsets) map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 1ff286b85c..6b78e7276c 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -1,24 +1,21 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement subgraph fusion. """ import dace import networkx as nx -from dace import dtypes, registry, symbolic, subsets, data -from dace.sdfg import nodes, utils, replace, SDFG, scope_contains_scope -from dace.sdfg.graph import SubgraphView -from dace.sdfg.scope import ScopeTree +from dace import dtypes, symbolic, subsets, data +from dace.sdfg import nodes, SDFG from dace.memlet import Memlet +from dace.sdfg.state import SDFGState, StateSubgraphView from dace.transformation import transformation from dace.properties import EnumProperty, ListProperty, make_properties, Property -from dace.symbolic import overapproximate -from dace.sdfg.propagation import propagate_memlets_sdfg, propagate_memlet, propagate_memlets_scope, _propagate_node +from dace.sdfg.propagation import _propagate_node from dace.transformation.subgraph import helpers -from dace.transformation.dataflow import RedundantArray -from dace.sdfg.utils import consolidate_edges_scope, get_view_node +from dace.sdfg.utils import consolidate_edges_scope from dace.transformation.helpers import find_contiguous_subsets from copy import deepcopy as dcpy -from typing import List, Union, Tuple +from typing import List, Tuple import warnings import dace.libraries.standard as stdlib @@ -74,7 +71,7 @@ class SubgraphFusion(transformation.SubgraphTransformation): desc="A list of array names to treat as non-transients and not compress", ) - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: + def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: """ Fusible if @@ -89,7 +86,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: 4. Check for any disjoint accesses of arrays. """ # get graph - graph = subgraph.graph + graph: SDFGState = subgraph.graph for node in subgraph.nodes(): if node not in graph.nodes(): return False @@ -626,7 +623,7 @@ def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, # do a full global search and count each data from each intermediate node scope_dict = graph.scope_dict() - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.AccessNode) and node.data in data_intermediate: # add them to the counter set in all cases diff --git a/tests/codegen/control_flow_detection_test.py b/tests/codegen/control_flow_detection_test.py index feeb6ff4fc..aaf0e11d42 100644 --- a/tests/codegen/control_flow_detection_test.py +++ b/tests/codegen/control_flow_detection_test.py @@ -65,7 +65,7 @@ def looptest(): sdfg: dace.SDFG = looptest.to_sdfg(simplify=True) if dace.Config.get_bool('optimizer', 'detect_control_flow'): - assert 'for (' in sdfg.generate_code()[0].code + assert 'while (' in sdfg.generate_code()[0].code A = looptest() A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) From bf7d82261036eee693de7bdc54d2e6329afe53bc Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 17 Oct 2024 11:10:09 +0200 Subject: [PATCH 051/108] Fix StencilTiling --- dace/sdfg/analysis/cutout.py | 2 +- dace/transformation/dataflow/__init__.py | 2 +- dace/transformation/dataflow/map_fission.py | 4 ++-- dace/transformation/dataflow/map_for_loop.py | 23 +------------------ .../transformation/subgraph/stencil_tiling.py | 19 +++++---------- dace/transformation/transformation.py | 7 ------ 6 files changed, 11 insertions(+), 46 deletions(-) diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index 5d2eae7c6f..48b8e98f98 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -536,7 +536,7 @@ def _transformation_determine_affected_nodes( if transformation.cfg_id >= 0 and target_sdfg.cfg_list: target_sdfg = target_sdfg.cfg_list[transformation.cfg_id] - subgraph = transformation.get_subgraph(target_sdfg) + subgraph = transformation.subgraph_view(target_sdfg) for n in subgraph.nodes(): affected_nodes.add(n) diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index 6fa274f041..e12ee8e1a9 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -5,7 +5,7 @@ from .mapreduce import MapReduceFusion, MapWCRFusion from .map_expansion import MapExpansion from .map_collapse import MapCollapse -from .map_for_loop import MapToForLoop, MapToForLoopRegion +from .map_for_loop import MapToForLoop from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle from .map_fusion import MapFusion diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index f3a2be08b7..f0e8499c92 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -64,7 +64,7 @@ def _components(subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node] return ns @staticmethod - def _border_arrays(sdfg, parent, subgraph): + def _border_arrays(sdfg: sd.SDFG, parent, subgraph): """ Returns a set of array names that are local to the fission subgraph. """ nested = isinstance(parent, sd.SDFGState) @@ -175,7 +175,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Find all nodes not in subgraph not_subgraph = set(n.data for n in graph.nodes() if n not in snodes and isinstance(n, nodes.AccessNode)) not_subgraph.update( - set(n.data for s in sdfg.nodes() if s != graph for n in s.nodes() + set(n.data for s in sdfg.states() if s != graph for n in s.nodes() if isinstance(n, nodes.AccessNode))) for _, component_out in components: diff --git a/dace/transformation/dataflow/map_for_loop.py b/dace/transformation/dataflow/map_for_loop.py index d7148fc651..f224e9dbcf 100644 --- a/dace/transformation/dataflow/map_for_loop.py +++ b/dace/transformation/dataflow/map_for_loop.py @@ -12,7 +12,7 @@ from typing import Tuple, Optional -class MapToForLoopRegion(transformation.SingleStateTransformation): +class MapToForLoop(transformation.SingleStateTransformation): """ Implements the Map to for-loop transformation. Takes a map and enforces a sequential schedule by transforming it into a loop region. Creates a nested SDFG, if @@ -115,24 +115,3 @@ def replace_param(param): sdfg.root_sdfg.using_experimental_blocks = True return node, nstate - - -class MapToForLoop(MapToForLoopRegion): - """ Implements the Map to for-loop transformation. - - Takes a map and enforces a sequential schedule by transforming it into - a state-machine of a for-loop. Creates a nested SDFG, if necessary. - """ - - before_state: SDFGState - guard: SDFGState - after_state: SDFGState - - def apply(self, graph: SDFGState, sdfg: SDFG) -> Tuple[nodes.NestedSDFG, SDFGState]: - node, nstate = super().apply(graph, sdfg) - _, (self.before_state, self.guard, self.after_state) = self.loop_region.inline() - - sdfg.reset_cfg_list() - sdfg.recheck_using_experimental_blocks() - - return node, nstate diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 018bc723f3..29989292be 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -6,6 +6,7 @@ from dace import dtypes, symbolic from dace.properties import make_properties, Property, ShapeProperty from dace.sdfg import nodes +from dace.sdfg.state import SDFGState from dace.transformation import transformation from dace.sdfg.propagation import _propagate_node @@ -302,8 +303,8 @@ def can_be_applied(sdfg, subgraph) -> bool: return True def apply(self, sdfg): - graph = sdfg.node(self.state_id) subgraph = self.subgraph_view(sdfg) + graph: SDFGState = subgraph.graph map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph) result = StencilTiling.topology(sdfg, graph, map_entries) @@ -427,7 +428,7 @@ def apply(self, sdfg): stripmine_subgraph = {StripMining.map_entry: graph.node_id(map_entry)} - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id last_map_entry = None original_schedule = map_entry.schedule self.tile_sizes = [] @@ -554,7 +555,7 @@ def apply(self, sdfg): if l > 1: subgraph = {MapExpansion.map_entry: graph.node_id(map_entry)} trafo_expansion = MapExpansion() - trafo_expansion.setup_match(sdfg, sdfg.cfg_id, sdfg.nodes().index(graph), subgraph, 0) + trafo_expansion.setup_match(sdfg, graph.parent_graph.cfg_id, graph.block_id, subgraph, 0) trafo_expansion.apply(graph, sdfg) maps = [map_entry] for _ in range(l - 1): @@ -565,7 +566,7 @@ def apply(self, sdfg): # MapToForLoop subgraph = {MapToForLoop.map_entry: graph.node_id(map)} trafo_for_loop = MapToForLoop() - trafo_for_loop.setup_match(sdfg, sdfg.cfg_id, sdfg.nodes().index(graph), subgraph, 0) + trafo_for_loop.setup_match(sdfg, graph.parent_graph.cfg_id, graph.block_id, subgraph, 0) trafo_for_loop.apply(graph, sdfg) nsdfg = trafo_for_loop.nsdfg @@ -573,15 +574,7 @@ def apply(self, sdfg): # Prevent circular import from dace.transformation.interstate.loop_unroll import LoopUnroll - guard = trafo_for_loop.guard - end = trafo_for_loop.after_state - begin = next(e.dst for e in nsdfg.out_edges(guard) if e.dst != end) - - subgraph = { - DetectLoop.loop_guard: nsdfg.node_id(guard), - DetectLoop.loop_begin: nsdfg.node_id(begin), - DetectLoop.exit_state: nsdfg.node_id(end) - } + subgraph = { LoopUnroll.loop: trafo_for_loop.loop_region.block_id } transformation = LoopUnroll() transformation.setup_match(nsdfg, nsdfg.cfg_id, -1, subgraph, 0) transformation.apply(nsdfg, nsdfg) diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 3b89612026..0cec526a2c 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -824,13 +824,6 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], cfg_id: int = self.cfg_id = cfg_id self.state_id = state_id - def get_subgraph(self, sdfg: SDFG) -> gr.SubgraphView: - sdfg = sdfg.cfg_list[self.cfg_id] - if self.state_id == -1: - return gr.SubgraphView(sdfg, list(map(sdfg.node, self.subgraph))) - state = sdfg.node(self.state_id) - return st.StateSubgraphView(state, list(map(state.node, self.subgraph))) - @classmethod def subclasses_recursive(cls) -> Set[Type['PatternTransformation']]: """ From 78803d53fe5b916cb298015af85602bd501a5685 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 17 Oct 2024 11:41:00 +0200 Subject: [PATCH 052/108] Add some region inlining --- dace/frontend/python/parser.py | 1 - dace/sdfg/utils.py | 41 ++++-------- dace/transformation/pass_pipeline.py | 3 + dace/transformation/passes/fusion_inline.py | 71 ++++++++++++++++++++- dace/transformation/passes/simplify.py | 26 ++++++-- tests/sdfg/control_flow_inline_test.py | 16 ++--- 6 files changed, 114 insertions(+), 44 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 4a650325bd..ebe8b04518 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -495,7 +495,6 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF if not self.use_experimental_cfg_blocks: for nsdfg in sdfg.all_sdfgs_recursive(): - sdutils.inline_conditional_blocks(nsdfg) sdutils.inline_control_flow_regions(nsdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 8a1fb68081..dd91678fcb 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,13 +13,13 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import (ConditionalBlock, ControlFlowBlock, SDFGState, StateSubgraphView, LoopRegion, +from dace.sdfg.state import (AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion) from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs from dace.cli.progress import optional_progressbar -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Sequence, Tuple, Type, Union def node_path_graph(*args) -> gr.OrderedDiGraph: @@ -1263,21 +1263,16 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> return counter -def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, LoopRegion)] - count = 0 - - for _block in optional_progressbar(reversed(blocks), title='Inlining Loops', - n=len(blocks), progress=progress): - block: LoopRegion = _block - if block.inline()[0]: - count += 1 - - return count - - -def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ControlFlowRegion)] +def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractControlFlowRegion]]] = None, + blacklist: Optional[List[Type[AbstractControlFlowRegion]]] = None, + progress: bool = None) -> int: + if types: + blocks = [n for n, _ in sdfg.all_nodes_recursive() if type(n) in types] + elif blacklist: + blocks = [n for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, AbstractControlFlowRegion) and type(n) not in blacklist] + else: + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, AbstractControlFlowRegion)] count = 0 for _block in optional_progressbar(reversed(blocks), title='Inlining control flow regions', @@ -1288,18 +1283,6 @@ def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: return count -def inline_conditional_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)] - count = 0 - - for _block in optional_progressbar(reversed(blocks), title='Inlining conditional blocks', - n=len(blocks), progress=progress): - block: ConditionalBlock = _block - if block.inline()[0]: - count += 1 - - return count - def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int: """ diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 9651e7d208..30026bba6d 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -134,6 +134,9 @@ def subclasses_recursive(cls) -> Set[Type['Pass']]: return result + def set_opts(self, opts: Dict[str, Any]) -> None: + pass + @properties.make_properties class VisitorPass(Pass): """ diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 15d1eeca5f..5cf75c6b62 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -8,7 +8,8 @@ from dace import SDFG, properties from dace.sdfg import nodes -from dace.sdfg.utils import fuse_states, inline_sdfgs +from dace.sdfg.state import ConditionalBlock, FunctionCallRegion, LoopRegion, NamedRegion +from dace.sdfg.utils import fuse_states, inline_control_flow_regions, inline_sdfgs from dace.transformation import pass_pipeline as ppl from dace.transformation.transformation import experimental_cfg_block_compatible @@ -88,6 +89,74 @@ def report(self, pass_retval: int) -> str: return f'Inlined {pass_retval} SDFGs.' +@dataclass(unsafe_hash=True) +@properties.make_properties +@experimental_cfg_block_compatible +class InlineControlFlowRegions(ppl.Pass): + """ + Inlines all control flow regions. + """ + + CATEGORY: str = 'Simplification' + + progress = properties.Property(dtype=bool, + default=None, + allow_none=True, + desc='Whether to print progress, or None for default (print after 5 seconds).') + + no_inline_loops = properties.Property(dtype=bool, default=True, desc='Whether to prevent inlining loops.') + no_inline_conditional = properties.Property(dtype=bool, default=True, + desc='Whether to prevent inlining conditional blocks.') + no_inline_function_call_regions = properties.Property(dtype=bool, default=True, + desc='Whether to prevent inlining function call regions.') + no_inline_named_regions = properties.Property(dtype=bool, default=True, + desc='Whether to prevent inlining named control flow regions.') + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.NestedSDFGs | ppl.Modifies.States) + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.States | ppl.Modifies.NestedSDFGs + + def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: + """ + Inlines all possible nested SDFGs (and all sub-SDFGs). + + :param sdfg: The SDFG to transform. + + :return: The total number of states fused, or None if did not apply. + """ + blacklist = [] + if self.no_inline_loops: + blacklist.append(LoopRegion) + if self.no_inline_conditional: + blacklist.append(ConditionalBlock) + if self.no_inline_named_regions: + blacklist.append(NamedRegion) + if self.no_inline_function_call_regions: + blacklist.append(FunctionCallRegion) + if len(blacklist) < 1: + blacklist = None + + inlined = inline_control_flow_regions(sdfg, None, blacklist, self.progress) + return inlined or None + + def report(self, pass_retval: int) -> str: + return f'Inlined {pass_retval} regions.' + + def set_opts(self, opts): + opt_keys = [ + 'no_inline_loops', + 'no_inline_conditional', + 'no_inline_function_call_regions', + 'no_inline_named_regions', + ] + + for k in opt_keys: + if k in opts: + setattr(self, k, opts[k]) + + @dataclass(unsafe_hash=True) @properties.make_properties @experimental_cfg_block_compatible diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index c6177966f4..33d9d80b32 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -1,6 +1,6 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass -from typing import Any, Dict, Optional, Set +from typing import Any, Dict, List, Optional, Set import warnings from dace import SDFG, config, properties @@ -11,7 +11,7 @@ from dace.transformation.passes.constant_propagation import ConstantPropagation from dace.transformation.passes.dead_dataflow_elimination import DeadDataflowElimination from dace.transformation.passes.dead_state_elimination import DeadStateElimination -from dace.transformation.passes.fusion_inline import FuseStates, InlineSDFGs +from dace.transformation.passes.fusion_inline import FuseStates, InlineControlFlowRegions, InlineSDFGs from dace.transformation.passes.optional_arrays import OptionalArrayInference from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols @@ -21,6 +21,7 @@ SIMPLIFY_PASSES = [ InlineSDFGs, + InlineControlFlowRegions, ScalarToSymbolPromotion, ControlFlowRaising, FuseStates, @@ -62,15 +63,21 @@ class SimplifyPass(ppl.FixedPointPipeline): skip = properties.SetProperty(element_type=str, default=set(), desc='Set of pass names to skip.') verbose = properties.Property(dtype=bool, default=False, desc='Whether to print reports after every pass.') + no_inline_function_call_regions = properties.Property(dtype=bool, default=False, + desc='Whether to prevent inlining function call regions.') + no_inline_named_regions = properties.Property(dtype=bool, default=False, + desc='Whether to prevent inlining named control flow regions.') + def __init__(self, validate: bool = False, validate_all: bool = False, skip: Optional[Set[str]] = None, - verbose: bool = False): + verbose: bool = False, + pass_options: Optional[Dict[str, Any]] = None): if skip: - passes = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip] + passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip] else: - passes = [p() for p in SIMPLIFY_PASSES] + passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES] super().__init__(passes=passes) self.validate = validate @@ -81,6 +88,15 @@ def __init__(self, else: self.verbose = verbose + pass_opts = { + 'no_inline_function_call_regions': self.no_inline_function_call_regions, + 'no_inline_named_regions': self.no_inline_named_regions, + } + if pass_options: + pass_opts.update(pass_options) + for p in passes: + p.set_opts(pass_opts) + def apply_subpass(self, sdfg: SDFG, p: ppl.Pass, state: Dict[str, Any]): """ Apply a pass from the pipeline. This method is meant to be overridden by subclasses. diff --git a/tests/sdfg/control_flow_inline_test.py b/tests/sdfg/control_flow_inline_test.py index 87af09b9c4..3a4cfd7c13 100644 --- a/tests/sdfg/control_flow_inline_test.py +++ b/tests/sdfg/control_flow_inline_test.py @@ -19,7 +19,7 @@ def test_loop_inlining_regular_for(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong assert len(states) == 8 @@ -41,7 +41,7 @@ def test_loop_inlining_regular_while(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong guard = None @@ -75,7 +75,7 @@ def test_loop_inlining_do_while(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong guard = None @@ -115,7 +115,7 @@ def test_loop_inlining_do_for(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong guard = None @@ -175,7 +175,7 @@ def test_inline_triple_nested_for(): reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', '0:N, 0:M, 0:K')) reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', '0:N, 0:M')) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) assert len(sdfg.nodes()) == 14 assert not any(isinstance(s, LoopRegion) for s in sdfg.nodes()) @@ -203,7 +203,7 @@ def test_loop_inlining_for_continue_break(): state7 = sdfg.add_state('state7') sdfg.add_edge(loop1, state7, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong assert len(states) == 12 @@ -240,7 +240,7 @@ def test_loop_inlining_multi_assignments(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong assert len(states) == 8 @@ -282,7 +282,7 @@ def test_loop_inlining_invalid_update_statement(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) nodes = sdfg.nodes() assert len(nodes) == 3 From 26a7fbc65d0a437ebc3231b2032ef417814af8a8 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 17 Oct 2024 12:45:01 +0200 Subject: [PATCH 053/108] Fixes to deepcopy --- dace/sdfg/state.py | 2 +- dace/sdfg/utils.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 2ff947e02c..fd70c5ac6a 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1193,7 +1193,7 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k in ('_parent_graph', '_sdfg', 'guid'): # Skip derivative attributes and GUID + if k in ('_parent_graph', '_sdfg', '_cfg_list', 'guid'): # Skip derivative attributes and GUID continue setattr(result, k, copy.deepcopy(v, memo)) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index dd91678fcb..f2929e2ac6 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1278,6 +1278,9 @@ def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractCo for _block in optional_progressbar(reversed(blocks), title='Inlining control flow regions', n=len(blocks), progress=progress): block: ControlFlowRegion = _block + # Control flow regions where the parent is a conditional block are not inlined. + if block.parent_graph and type(block.parent_graph) == ConditionalBlock: + continue if block.inline()[0]: count += 1 From 1862ba4733b7c7bdea9e8b0c84f02377d28b8ac7 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 17 Oct 2024 14:19:40 +0200 Subject: [PATCH 054/108] More various fixes --- dace/codegen/targets/framecode.py | 19 ++++++++- dace/sdfg/nodes.py | 4 +- dace/sdfg/state.py | 5 ++- .../transformation/dataflow/otf_map_fusion.py | 13 ++++--- dace/transformation/pass_pipeline.py | 4 +- .../passes/analysis/analysis.py | 13 +++++++ .../passes/constant_propagation.py | 39 ++++++++++++------- .../passes/dead_dataflow_elimination.py | 24 +++++++----- tests/transformations/redundant_copy_test.py | 1 - 9 files changed, 88 insertions(+), 34 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 54273ac6b0..8073b033b6 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -9,6 +9,7 @@ import dace from dace import config, data, dtypes +from dace import symbolic from dace.cli import progress from dace.codegen import control_flow as cflow from dace.codegen import dispatcher as disp @@ -20,7 +21,7 @@ from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion from dace.transformation.passes.analysis import StateReachability, loop_analysis @@ -691,10 +692,24 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): curstate: SDFGState = None multistate = False - # Does the array appear in inter-state edges? + # Does the array appear in inter-state edges or loop / conditional block conditions etc.? for isedge in sdfg.all_interstate_edges(): if name in self.free_symbols(isedge.data): multistate = True + for cfg in sdfg.all_control_flow_regions(): + block_syms = set() + if isinstance(cfg, LoopRegion): + block_syms |= symbolic.free_symbols_and_functions(cfg.loop_condition.as_string) + if cfg.update_statement is not None: + block_syms |= symbolic.free_symbols_and_functions(cfg.update_statement.as_string) + if cfg.init_statement is not None: + block_syms |= symbolic.free_symbols_and_functions(cfg.init_statement.as_string) + elif isinstance(cfg, ConditionalBlock): + for cond, _ in cfg.branches: + if cond is not None: + block_syms |= symbolic.free_symbols_and_functions(cond.as_string) + if name in block_syms: + multistate = True for state in sdfg.states(): if multistate: diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 4ae91d5ea0..436749e70e 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -1394,8 +1394,8 @@ def expand(self, sdfg, state, *args, **kwargs) -> str: if implementation not in self.implementations.keys(): raise KeyError("Unknown implementation for node {}: {}".format(type(self).__name__, implementation)) transformation_type = type(self).implementations[implementation] - cfg_id = sdfg.cfg_id - state_id = sdfg.nodes().index(state) + cfg_id = state.parent_graph.cfg_id + state_id = state.block_id subgraph = {transformation_type._match_node: state.node_id(self)} transformation: ExpandTransformation = transformation_type() transformation.setup_match(sdfg, cfg_id, state_id, subgraph, 0) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index fd70c5ac6a..30ddab9850 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1090,6 +1090,8 @@ def all_transients(self) -> List[str]: def replace(self, name: str, new_name: str): for n in self.nodes(): n.replace(name, new_name) + for e in self.edges(): + e.data.replace(name, new_name) def replace_dict(self, repl: Dict[str, str], @@ -3083,7 +3085,8 @@ def inline(self) -> Tuple[bool, Any]: # and return are inlined correctly. def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: for block in region.nodes(): - if (isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalBlock)) and not isinstance(block, LoopRegion): + if ((isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalBlock)) + and not isinstance(block, LoopRegion)): recursive_inline_cf_regions(block) block.inline() recursive_inline_cf_regions(self) diff --git a/dace/transformation/dataflow/otf_map_fusion.py b/dace/transformation/dataflow/otf_map_fusion.py index a793d1e679..aa339b75c5 100644 --- a/dace/transformation/dataflow/otf_map_fusion.py +++ b/dace/transformation/dataflow/otf_map_fusion.py @@ -478,8 +478,11 @@ def advanced_replace(subgraph: StateSubgraphView, s: str, s_: str) -> None: elif isinstance(node, nodes.NestedSDFG): for nsdfg in node.sdfg.all_sdfgs_recursive(): nsdfg.replace(s, s_) - for nstate in nsdfg.nodes(): - for nnode in nstate.nodes(): - if isinstance(nnode, nodes.MapEntry): - params = [s_ if p == s else p for p in nnode.map.params] - nnode.map.params = params + for cfg in nsdfg.all_control_flow_regions(): + cfg.replace(s, s_) + for nblock in cfg.nodes(): + if isinstance(nblock, SDFGState): + for nnode in nblock.nodes(): + if isinstance(nnode, nodes.MapEntry): + params = [s_ if p == s else p for p in nnode.map.params] + nnode.map.params = params diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 30026bba6d..0174f0ff3e 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Type, Union from dataclasses import dataclass -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion class Modifies(Flag): @@ -287,6 +287,8 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D """ result = {} for region in sdfg.all_control_flow_regions(recursive=True, parent_first=False): + if isinstance(region, ConditionalBlock): + continue retval = self.apply(region, pipeline_results) if retval is not None: result[region.cfg_id] = retval diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 447efea42c..6202edb539 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -274,6 +274,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tupl top_result: Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + arrays: Set[str] = set(sdfg.arrays.keys()) for block in sdfg.all_control_flow_blocks(): readset, writeset = set(), set() if isinstance(block, SDFGState): @@ -289,6 +290,18 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tupl writeset.add(anode.data) if state.out_degree(anode) > 0: readset.add(anode.data) + if isinstance(block, LoopRegion): + exprs = set([ block.loop_condition.as_string ]) + if block.update_statement is not None: + exprs.add(block.update_statement.as_string) + if block.init_statement is not None: + exprs.add(block.init_statement.as_string) + for expr in exprs: + readset |= symbolic.free_symbols_and_functions(expr) & arrays + elif isinstance(block, ConditionalBlock): + for cond, _ in block.branches: + if cond is not None: + readset |= symbolic.free_symbols_and_functions(cond.as_string) & arrays result[block] = (readset, writeset) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index f6311dea6f..5350106672 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -263,6 +263,15 @@ def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arra out_consts[k] = _UnknownValue out_const_dict[conditional] = out_consts + def _assignments_in_loop(self, loop: LoopRegion) -> Set[str]: + assignments_within = set() + for e in loop.all_interstate_edges(): + for k in e.data.assignments.keys(): + assignments_within.add(k) + if loop.loop_variable is not None: + assignments_within.add(loop.loop_variable) + return assignments_within + def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str], in_const_dict: BlockConstsT, pre_const_dict: BlockConstsT, post_const_dict: BlockConstsT, out_const_dict: BlockConstsT) -> None: @@ -289,13 +298,8 @@ def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str] # In the case of a loop, the 'pre constants' are equivalent to the 'in constants', with the exception # of values that may at any point be re-assigned inside the loop, since that assignment would carry over # into the next iteration (including increments to the loop variable, if present). - assignments_within = set() - for e in cfg.all_interstate_edges(): - for k in e.data.assignments.keys(): - assignments_within.add(k) - if cfg.loop_variable is not None: - assignments_within.add(cfg.loop_variable) - pre_const = { k: (v if k not in assignments_within else _UnknownValue) for k, v in in_const.items() } + assigned_in_loop = self._assignments_in_loop(cfg) + pre_const = { k: (v if k not in assigned_in_loop else _UnknownValue) for k, v in in_const.items() } else: # In any other case, the 'pre constants' are equivalent to the 'in constants'. pre_const = {} @@ -359,22 +363,31 @@ def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str] reassignments = replacements & edge.data.assignments.keys() if reassignments and (used_in_assignments - reassignments): assignments[aname] = _UnknownValue + + if isinstance(block, LoopRegion): + # Any constants before a loop that may be overwritten inside the loop cannot be assumed as constants + # for the loop itself. + assigned_in_loop = self._assignments_in_loop(block) + for k in assignments.keys(): + if k in assigned_in_loop: + assignments[k] = _UnknownValue + if block not in in_const_dict: in_const_dict[block] = {} if assignments: redo |= self._propagate(in_const_dict[block], assignments) - if isinstance(block, SDFGState): - # Simple case, no change in constants through this block. - pre_const_dict[block] = in_const_dict[block] - post_const_dict[block] = in_const_dict[block] - out_const_dict[block] = in_const_dict[block] - elif isinstance(block, ControlFlowRegion): + if isinstance(block, ControlFlowRegion): self._collect_constants_for_region(block, arrays, in_const_dict, pre_const_dict, post_const_dict, out_const_dict) elif isinstance(block, ConditionalBlock): self._collect_constants_for_conditional(block, arrays, in_const_dict, pre_const_dict, post_const_dict, out_const_dict) + else: + # Simple case, no change in constants through this block (states and other basic blocks). + pre_const_dict[block] = in_const_dict[block] + post_const_dict[block] = in_const_dict[block] + out_const_dict[block] = in_const_dict[block] # For all sink nodes, compute the overlapping set of constants between them, making sure all constants in the # resulting intersection are actually constants (i.e., all blocks see the same constant value for them). This diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 964d362d99..e429fa902d 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -11,6 +11,7 @@ from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg from dace.sdfg import infer_types +from dace.sdfg.state import ControlFlowBlock from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap @@ -20,7 +21,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties @transformation.experimental_cfg_block_compatible -class DeadDataflowElimination(ppl.Pass): +class DeadDataflowElimination(ppl.ControlFlowRegionPass): """ Removes unused computations from SDFG states. Traverses the graph backwards, removing any computations that result in transient descriptors @@ -44,9 +45,9 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.CFG) def depends_on(self) -> Set[Type[ppl.Pass]]: - return {ap.StateReachability, ap.AccessSets} + return {ap.ControlFlowBlockReachability, ap.AccessSets} - def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[SDFGState, Set[str]]]: + def apply(self, region, pipeline_results): """ Removes unreachable dataflow throughout SDFG states. @@ -57,15 +58,20 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D :return: A dictionary mapping states to removed data descriptor names, or None if nothing changed. """ # Depends on the following analysis passes: - # * State reachability - # * Read/write access sets per state - reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results[ap.StateReachability.__name__][sdfg.cfg_id] - access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results[ap.AccessSets.__name__][sdfg.cfg_id] + # * Control flow block reachability + # * Read/write access sets per block + sdfg = region if isinstance(region, SDFG) else region.sdfg + reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[ + ap.ControlFlowBlockReachability.__name__ + ][region.cfg_id] + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[ + ap.AccessSets.__name__ + ][sdfg.cfg_id] result: Dict[SDFGState, Set[str]] = defaultdict(set) - # Traverse SDFG backwards + # Traverse region backwards try: - state_order: List[SDFGState] = list(cfg.blockorder_topological_sort(sdfg, recursive=True, + state_order: List[SDFGState] = list(cfg.blockorder_topological_sort(region, recursive=False, ignore_nonstate_blocks=True)) except KeyError: return None diff --git a/tests/transformations/redundant_copy_test.py b/tests/transformations/redundant_copy_test.py index 2c753c6fc5..280d5f182a 100644 --- a/tests/transformations/redundant_copy_test.py +++ b/tests/transformations/redundant_copy_test.py @@ -450,7 +450,6 @@ def flip_and_flatten(a, b): if __name__ == '__main__': - test_slicing_with_redundant_arrays() test_in() test_out() test_out_success() From 422edb5c5a6fb48e0014bd3ffd49a8f928c43255 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 17 Oct 2024 14:47:04 +0200 Subject: [PATCH 055/108] Yet more fixes --- dace/codegen/targets/framecode.py | 10 ++++++---- dace/sdfg/performance_evaluation/helpers.py | 4 ++-- dace/transformation/passes/analysis/__init__.py | 1 + dace/transformation/passes/analysis/analysis.py | 12 ++++++++---- tests/transformations/loop_manipulation_test.py | 12 ++++++------ 5 files changed, 23 insertions(+), 16 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 8073b033b6..62cfd03f23 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -700,10 +700,12 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): block_syms = set() if isinstance(cfg, LoopRegion): block_syms |= symbolic.free_symbols_and_functions(cfg.loop_condition.as_string) - if cfg.update_statement is not None: - block_syms |= symbolic.free_symbols_and_functions(cfg.update_statement.as_string) - if cfg.init_statement is not None: - block_syms |= symbolic.free_symbols_and_functions(cfg.init_statement.as_string) + update_stmt = loop_analysis.get_update_assignment(cfg) + init_stmt = loop_analysis.get_init_assignment(cfg) + if update_stmt: + block_syms |= symbolic.free_symbols_and_functions(update_stmt) + if init_stmt: + block_syms |= symbolic.free_symbols_and_functions(init_stmt) elif isinstance(cfg, ConditionalBlock): for cond, _ in cfg.branches: if cond is not None: diff --git a/dace/sdfg/performance_evaluation/helpers.py b/dace/sdfg/performance_evaluation/helpers.py index 552e2917cc..0272101562 100644 --- a/dace/sdfg/performance_evaluation/helpers.py +++ b/dace/sdfg/performance_evaluation/helpers.py @@ -34,9 +34,9 @@ def get_uuid(element, state=None): if isinstance(element, SDFG): return ids_to_string(element.cfg_id) elif isinstance(element, SDFGState): - return ids_to_string(element.parent.cfg_id, element.parent.node_id(element)) + return ids_to_string(element.parent_graph.cfg_id, element.block_id) elif isinstance(element, nodes.Node): - return ids_to_string(state.parent.cfg_id, state.parent.node_id(state), state.node_id(element)) + return ids_to_string(state.parent_graph.cfg_id, state.block_id, state.node_id(element)) else: return ids_to_string(-1) diff --git a/dace/transformation/passes/analysis/__init__.py b/dace/transformation/passes/analysis/__init__.py index 5bc1f6e3f3..b0b39f3c4b 100644 --- a/dace/transformation/passes/analysis/__init__.py +++ b/dace/transformation/passes/analysis/__init__.py @@ -1 +1,2 @@ from .analysis import * +from .loop_analysis import * diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 6202edb539..2ed91cea8c 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -12,6 +12,8 @@ import networkx as nx from networkx.algorithms import shortest_paths as nxsp +from dace.transformation.passes.analysis import loop_analysis + WriteScopeDict = Dict[str, Dict[Optional[Tuple[SDFGState, nd.AccessNode]], Set[Union[Tuple[SDFGState, nd.AccessNode], Tuple[ControlFlowBlock, InterstateEdge]]]]] SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], ControlFlowBlock]]]] @@ -292,10 +294,12 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tupl readset.add(anode.data) if isinstance(block, LoopRegion): exprs = set([ block.loop_condition.as_string ]) - if block.update_statement is not None: - exprs.add(block.update_statement.as_string) - if block.init_statement is not None: - exprs.add(block.init_statement.as_string) + update_stmt = loop_analysis.get_update_assignment(block) + init_stmt = loop_analysis.get_init_assignment(block) + if update_stmt: + exprs.add(update_stmt) + if init_stmt: + exprs.add(init_stmt) for expr in exprs: readset |= symbolic.free_symbols_and_functions(expr) & arrays elif isinstance(block, ConditionalBlock): diff --git a/tests/transformations/loop_manipulation_test.py b/tests/transformations/loop_manipulation_test.py index 9a3abc0239..7d87e1d2b9 100644 --- a/tests/transformations/loop_manipulation_test.py +++ b/tests/transformations/loop_manipulation_test.py @@ -27,9 +27,9 @@ def regression(A, B): def test_unroll(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 2 + assert len(sdfg.nodes()) == 1 sdfg.apply_transformations(LoopUnroll) - assert len(sdfg.nodes()) == 1 + 5 * 2 + assert len(sdfg.nodes()) == 5 * 2 sdfg.simplify() assert len(sdfg.nodes()) == 1 A = np.random.rand(20) @@ -47,9 +47,9 @@ def test_unroll(): def test_peeling_start(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 2 + assert len(sdfg.nodes()) == 1 sdfg.apply_transformations(LoopPeeling, dict(count=2)) - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 3 A = np.random.rand(20) B = np.random.rand(20) reg = regression(A, B) @@ -65,9 +65,9 @@ def test_peeling_start(): def test_peeling_end(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 2 + assert len(sdfg.nodes()) == 1 sdfg.apply_transformations(LoopPeeling, dict(count=2, begin=False)) - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 3 A = np.random.rand(20) B = np.random.rand(20) reg = regression(A, B) From a66c6108d8ea2b6a48874358e2da0c69366cfd57 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 18 Oct 2024 08:47:48 +0200 Subject: [PATCH 056/108] More fixes --- dace/transformation/helpers.py | 2 +- .../interstate/fpga_transform_state.py | 5 +--- .../interstate/multistate_inline.py | 10 +++++++- dace/transformation/passes/fusion_inline.py | 13 ++++++++-- tests/python_frontend/augassign_wcr_test.py | 12 +++++----- tests/sdfg/schedule_inference_test.py | 24 +++++++++---------- 6 files changed, 40 insertions(+), 26 deletions(-) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 19bea0f0e8..dc84bd4478 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1218,7 +1218,7 @@ def traverse(state: SDFGState, treenode: ScopeTree): snodes = state.scope_children()[treenode.entry] for node in snodes: if isinstance(node, nodes.NestedSDFG): - for nstate in node.sdfg.nodes(): + for nstate in node.sdfg.states(): ntree = nstate.scope_tree()[None] ntree.state = nstate treenode.children.append(ntree) diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index 47ba478341..d92c0c058c 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -100,14 +100,11 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): wcr_input_nodes = set() stack = [] - parent_sdfg = {state: sdfg} # Map states to their parent SDFG for node, graph in state.all_nodes_recursive(): - if isinstance(graph, dace.SDFG): - parent_sdfg[node] = graph if isinstance(node, dace.sdfg.nodes.AccessNode): for e in graph.in_edges(node): if e.data.wcr is not None: - trace = dace.sdfg.trace_nested_access(node, graph, parent_sdfg[graph]) + trace = dace.sdfg.trace_nested_access(node, graph, graph.sdfg) for node_trace, memlet_trace, state_trace, sdfg_trace in trace: # Find the name of the accessed node in our scope if state_trace == state and sdfg_trace == sdfg: diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index c39c5868d4..75eed021c2 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -14,7 +14,7 @@ from dace.transformation import transformation, helpers from dace.properties import make_properties from dace import data -from dace.sdfg.state import StateSubgraphView +from dace.sdfg.state import LoopRegion, StateSubgraphView @make_properties @@ -190,10 +190,18 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): outer_assignments = set() for e in sdfg.all_interstate_edges(): outer_assignments |= e.data.assignments.keys() + for b in sdfg.all_control_flow_blocks(): + if isinstance(b, LoopRegion): + if b.loop_variable is not None: + outer_assignments.add(b.loop_variable) inner_assignments = set() for e in nsdfg.all_interstate_edges(): inner_assignments |= e.data.assignments.keys() + for b in nsdfg.all_control_flow_blocks(): + if isinstance(b, LoopRegion): + if b.loop_variable is not None: + inner_assignments.add(b.loop_variable) allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys()) assignments_to_replace = inner_assignments & (outer_assignments | allnames) diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 5cf75c6b62..006ee7e2df 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -138,8 +138,17 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: if len(blacklist) < 1: blacklist = None - inlined = inline_control_flow_regions(sdfg, None, blacklist, self.progress) - return inlined or None + inlined = 0 + while True: + inlined_in_iteration = inline_control_flow_regions(sdfg, None, blacklist, self.progress) + if inlined_in_iteration < 1: + break + inlined += inlined_in_iteration + + if inlined: + sdfg.reset_cfg_list() + return inlined + return None def report(self, pass_retval: int) -> str: return f'Inlined {pass_retval} regions.' diff --git a/tests/python_frontend/augassign_wcr_test.py b/tests/python_frontend/augassign_wcr_test.py index e6964261fe..b29ddfcff1 100644 --- a/tests/python_frontend/augassign_wcr_test.py +++ b/tests/python_frontend/augassign_wcr_test.py @@ -60,8 +60,8 @@ def test_augassign_wcr(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=True): test_sdfg = augassign_wcr.to_sdfg(simplify=False) wcr_count = 0 - for sdfg in test_sdfg.cfg_list: - for state in sdfg.nodes(): + for sdfg in test_sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.wcr: wcr_count += 1 @@ -81,8 +81,8 @@ def test_augassign_wcr2(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=True): test_sdfg = augassign_wcr2.to_sdfg(simplify=False) wcr_count = 0 - for sdfg in test_sdfg.cfg_list: - for state in sdfg.nodes(): + for sdfg in test_sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.wcr: wcr_count += 1 @@ -105,8 +105,8 @@ def test_augassign_wcr3(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=True): test_sdfg = augassign_wcr3.to_sdfg(simplify=False) wcr_count = 0 - for sdfg in test_sdfg.cfg_list: - for state in sdfg.nodes(): + for sdfg in test_sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.wcr: wcr_count += 1 diff --git a/tests/sdfg/schedule_inference_test.py b/tests/sdfg/schedule_inference_test.py index 1b1b3422d8..8f4fcd6acb 100644 --- a/tests/sdfg/schedule_inference_test.py +++ b/tests/sdfg/schedule_inference_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Tests for default storage/schedule inference. """ import dace +from dace.sdfg.state import SDFGState from dace.sdfg.validation import InvalidSDFGNodeError from dace.sdfg.infer_types import set_default_schedule_and_storage_types from dace.transformation.helpers import get_parent_map @@ -95,10 +96,10 @@ def top(a: dace.float64[20, 20], b: dace.float64[20, 20]): sdfg = top.to_sdfg(simplify=False) set_default_schedule_and_storage_types(sdfg, None) - for node, state in sdfg.all_nodes_recursive(): - nsdfg = state.parent - if isinstance(node, dace.nodes.AccessNode): - assert node.desc(nsdfg).storage == dace.StorageType.CPU_Heap + for sd in sdfg.all_sdfgs_recursive(): + for state in sd.states(): + for dn in state.data_nodes(): + assert dn.desc(sd).storage == dace.StorageType.CPU_Heap def test_nested_storage_equivalence(): @@ -114,13 +115,13 @@ def top(a: dace.float64[20, 20] @ dace.StorageType.CPU_Heap, b: dace.float64[20, sdfg = top.to_sdfg(simplify=False) set_default_schedule_and_storage_types(sdfg, None) - for node, state in sdfg.all_nodes_recursive(): - nsdfg = state.parent - if isinstance(node, dace.nodes.AccessNode): - if state.out_degree(node) > 0: # Check for a in external and internal scopes - assert node.desc(nsdfg).storage == dace.StorageType.CPU_Heap - elif state.in_degree(node) > 0: # Check for b in external and internal scopes - assert node.desc(nsdfg).storage == dace.StorageType.CPU_Pinned + for sd in sdfg.all_sdfgs_recursive(): + for state in sd.states(): + for dn in state.data_nodes(): + if state.out_degree(dn) > 0: # Check for a in external and internal scopes + assert dn.desc(sd).storage == dace.StorageType.CPU_Heap + elif state.in_degree(dn) > 0: # Check for b in external and internal scopes + assert dn.desc(sd).storage == dace.StorageType.CPU_Pinned def test_ambiguous_schedule(): @@ -171,7 +172,6 @@ def add(a: dace.float32[10, 10] @ dace.StorageType.GPU_Global, test_gpu_schedule_autodetect() test_gpu_schedule_scalar_autodetect() test_gpu_schedule_scalar_autodetect_2() - test_nested_kernel_computation() test_nested_map_in_loop_schedule() test_nested_storage() test_nested_storage_equivalence() From b507a4b54896cc0b70644b0e5e6f11b57e6af222 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 18 Oct 2024 10:10:52 +0200 Subject: [PATCH 057/108] Fixes --- dace/sdfg/state.py | 5 +++- .../interstate/fpga_transform_state.py | 6 ++-- dace/transformation/interstate/loop_to_map.py | 2 +- dace/transformation/subgraph/composite.py | 8 ++--- tests/sdfg/work_depth_test.py | 30 +++++++++++++++---- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 30ddab9850..083513005f 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3192,7 +3192,9 @@ def _used_symbols_internal(self, free_syms |= self.init_statement.get_free_symbols() if self.update_statement is not None: free_syms |= self.update_statement.get_free_symbols() - free_syms |= self.loop_condition.get_free_symbols() + cond_free_syms = self.loop_condition.get_free_symbols() + if self.loop_variable and self.loop_variable in cond_free_syms: + cond_free_syms.remove(self.loop_variable) b_free_symbols, b_defined_symbols, b_used_before_assignment = super()._used_symbols_internal( all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) @@ -3203,6 +3205,7 @@ def _used_symbols_internal(self, defined_syms -= used_before_assignment free_syms -= defined_syms + free_syms |= cond_free_syms return free_syms, defined_syms, used_before_assignment diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index d92c0c058c..6e1af4ed16 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -100,11 +100,11 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): wcr_input_nodes = set() stack = [] - for node, graph in state.all_nodes_recursive(): + for node, pGraph in state.all_nodes_recursive(): if isinstance(node, dace.sdfg.nodes.AccessNode): - for e in graph.in_edges(node): + for e in pGraph.in_edges(node): if e.data.wcr is not None: - trace = dace.sdfg.trace_nested_access(node, graph, graph.sdfg) + trace = dace.sdfg.trace_nested_access(node, pGraph, pGraph.sdfg) for node_trace, memlet_trace, state_trace, sdfg_trace in trace: # Find the name of the accessed node in our scope if state_trace == state and sdfg_trace == sdfg: diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 496adc238f..5a50c54c45 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -394,7 +394,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Create NestedSDFG and add the loop contents to it. Gaher symbols defined in the NestedSDFG. fsymbols = set(sdfg.free_symbols) - body = graph.add_state('single_state_body') + body = graph.add_state('single_state_body', is_start_block=(graph.start_block is self.loop)) nsdfg = SDFG('loop_body', constants=sdfg.constants_prop, parent=body) nsdfg.add_node(self.loop.start_block, is_start_block=True) nsymbols = dict() diff --git a/dace/transformation/subgraph/composite.py b/dace/transformation/subgraph/composite.py index 886698f791..eda247e64f 100644 --- a/dace/transformation/subgraph/composite.py +++ b/dace/transformation/subgraph/composite.py @@ -64,9 +64,10 @@ def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: graph_indices = [i for (i, n) in enumerate(graph.nodes()) if n in subgraph] sdfg_copy = copy.deepcopy(sdfg) sdfg_copy.reset_cfg_list() - graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)] + par_graph_copy = sdfg_copy.cfg_list[graph.parent_graph.cfg_id] + graph_copy = par_graph_copy.nodes()[graph.block_id] subgraph_copy = SubgraphView(graph_copy, [graph_copy.nodes()[i] for i in graph_indices]) - expansion.cfg_id = sdfg_copy.cfg_id + expansion.cfg_id = par_graph_copy.cfg_id ##sdfg_copy.apply_transformations(MultiExpansion, states=[graph]) #expansion = MultiExpansion() @@ -100,9 +101,6 @@ def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: def apply(self, sdfg): subgraph = self.subgraph_view(sdfg) graph = subgraph.graph - scope_dict = graph.scope_dict() - map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph, scope_dict) - first_entry = next(iter(map_entries)) if self.allow_expansion: expansion = MultiExpansion() diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index e677cca752..11873fa03d 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ from typing import Dict, List, Tuple @@ -18,6 +18,8 @@ from pytest import raises +from dace.transformation.passes.fusion_inline import InlineControlFlowRegions + N = dc.symbol('N') M = dc.symbol('M') K = dc.symbol('K') @@ -192,9 +194,9 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_4') * N, sp.Symbol('num_execs_0_4'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_4') * N, 2 * sp.Symbol('num_execs_0_4'))), 'break_for_loop': (break_for_loop, (N**2, N)), 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), @@ -217,6 +219,15 @@ def test_work_depth(test_name): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) + + # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] # substitue each symbol without assumptions. @@ -264,6 +275,15 @@ def test_avg_par(test_name: str): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test_name: sdfg.apply_transformations(MapExpansion) + + # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] # substitue each symbol without assumptions. @@ -320,8 +340,8 @@ def test_assumption_system_contradictions(assumptions): for test_name in work_depth_test_cases.keys(): test_work_depth(test_name) - for test, correct in tests_cases_avg_par: - test_avg_par(test, correct) + for test_name in tests_cases_avg_par.keys(): + test_avg_par(test_name) for expr, assums, res in assumptions_tests: test_assumption_system(expr, assums, res) From 4084dfe9eaf11eafd59b85ed9ec47abca71e66cc Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 18 Oct 2024 11:53:36 +0200 Subject: [PATCH 058/108] More fixes --- dace/transformation/pass_pipeline.py | 26 -------- .../passes/analysis/analysis.py | 2 +- .../simplification/control_flow_raising.py | 2 +- dace/transformation/passes/simplify.py | 2 + .../control_flow_raising_test.py | 37 ++++++++++- .../symbol_write_scopes_analysis_test.py | 4 +- .../writeset_underapproximation_test.py | 65 +++++++++++++++++++ 7 files changed, 105 insertions(+), 33 deletions(-) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 0174f0ff3e..d8bd8745ff 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -547,35 +547,9 @@ def apply_subpass(self, sdfg: SDFG, p: Pass, state: Dict[str, Any]) -> Optional[ :param state: The pipeline results state. :return: The pass return value. """ - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(p, '__experimental_cfg_block_compatible__') or - p.__experimental_cfg_block_compatible__ == False): - warnings.warn(p.__class__.__name__ + ' is not being applied due to incompatibility with ' + - 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + - 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + - 'True. If ' + p.__class__.__name__ + ' is compatible with experimental blocks, ' + - 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + - '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + - 'for more information.') - return None - return p.apply_pass(sdfg, state) def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Any]]: - if sdfg.root_sdfg.using_experimental_blocks: - if (type(self) != Pipeline and (not hasattr(self, '__experimental_cfg_block_compatible__') or - self.__experimental_cfg_block_compatible__ == False)): - warnings.warn('Pipeline ' + self.__class__.__name__ + ' is being skipped due to incompatibility with ' + - 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + - 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + - 'True. If ' + self.__class__.__name__ + ' is compatible with experimental blocks, ' + - 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + - '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + - 'for more information.') - return None - state = pipeline_results retval = {} self._modified = Modifies.Nothing diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 2ed91cea8c..bc1bac4640 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -462,7 +462,7 @@ def apply(self, region, pipeline_results) -> SymbolScopeDict: for read_loc, (reads, _) in symbol_access_sets.items(): for sym in reads: dominating_write = self._find_dominating_write(sym, read_loc, idom) - result[sym][dominating_write].add(read_loc if isinstance(read_loc, ControlFlowBlock) else read_loc.dst) + result[sym][dominating_write].add(read_loc) # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, then merge # A and its scope into B's scope. diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 12a2b2908a..fa1a3c6f97 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -5,7 +5,7 @@ from dace import properties from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion from dace.sdfg.utils import dfs_conditional from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.interstate.loop_lifting import LoopLifting diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 33d9d80b32..d3e8b580da 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -123,6 +123,8 @@ def apply_subpass(self, sdfg: SDFG, p: ppl.Pass, state: Dict[str, Any]): ret = ret or None else: ret = p.apply_pass(sdfg, state) + if ret is not None: + sdfg.reset_cfg_list() if self.verbose: if ret is not None: diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py index 53e01df12f..22701fb2bb 100644 --- a/tests/passes/simplification/control_flow_raising_test.py +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -4,6 +4,7 @@ import numpy as np from dace.sdfg.state import ConditionalBlock from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline +from dace.transformation.passes.fusion_inline import InlineControlFlowRegions from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising @@ -19,10 +20,17 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): sdfg = dataflow_if_check.to_sdfg() + # To test raising, we inline the control flow generated by the frontend. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) ppl = FixedPointPipeline([ControlFlowRaising()]) - ppl.__experimental_cfg_block_compatible__ = True ppl.apply_pass(sdfg, {}) assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -56,8 +64,21 @@ def nested_if_chain(i: dace.int64): sdfg = nested_if_chain.to_sdfg() + # To test raising, we inline the control flow generated by the frontend. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.apply_pass(sdfg, {}) + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + assert nested_if_chain(0)[0] == 0 assert nested_if_chain(2)[0] == 1 assert nested_if_chain(4)[0] == 2 @@ -80,9 +101,21 @@ def elif_chain(i: dace.int64): else: return 4 - elif_chain.use_experimental_cfg_blocks = True sdfg = elif_chain.to_sdfg() + # To test raising, we inline the control flow generated by the frontend. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.apply_pass(sdfg, {}) + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) assert elif_chain(0)[0] == 0 diff --git a/tests/passes/symbol_write_scopes_analysis_test.py b/tests/passes/symbol_write_scopes_analysis_test.py index 8450841729..0f3207a262 100644 --- a/tests/passes/symbol_write_scopes_analysis_test.py +++ b/tests/passes/symbol_write_scopes_analysis_test.py @@ -1,8 +1,6 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the symbol write scopes analysis pass. """ -import pytest - import dace from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis import SymbolWriteScopes, SymbolScopeDict diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 96df87b5e7..0e6f9d4fb4 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -5,6 +5,7 @@ from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDict from dace.subsets import Range from dace.transformation.pass_pipeline import Pipeline +from dace.transformation.passes.fusion_inline import InlineControlFlowRegions N = dace.symbol("N") M = dace.symbol("M") @@ -307,6 +308,14 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -331,6 +340,14 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -456,6 +473,14 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] write_approx = result[sdfg.cfg_id].approximation @@ -491,6 +516,14 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -524,6 +557,14 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -558,6 +599,14 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) result: Dict[int, UnderapproximateWritesDict] = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -824,6 +873,14 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -854,6 +911,14 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inliner = InlineControlFlowRegions() + inliner.no_inline_conditional = False + inliner.no_inline_loops = False + inliner.no_inline_function_call_regions = False + inliner.no_inline_named_regions = False + inliner.apply_pass(sdfg, {}) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] From d0106207d33972e7c89c2d96e89ab98a182deb5b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 18 Oct 2024 14:16:37 +0200 Subject: [PATCH 059/108] More fixes --- dace/codegen/control_flow.py | 4 +- .../analysis/schedule_tree/sdfg_to_tree.py | 2 +- dace/sdfg/replace.py | 12 +++ dace/sdfg/state.py | 4 +- dace/transformation/helpers.py | 98 +++++++++++-------- dace/transformation/pass_pipeline.py | 6 +- .../passes/analysis/analysis.py | 17 +++- .../passes/dead_state_elimination.py | 2 +- .../simplification/control_flow_raising.py | 1 + .../prune_empty_conditional_branches.py | 9 +- dace/transformation/passes/simplify.py | 5 + tests/fortran/array_test.py | 8 +- tests/passes/dead_code_elimination_test.py | 4 +- .../prune_empty_conditional_branches_test.py | 4 +- .../python_frontend/function_regions_test.py | 23 ++--- tests/python_frontend/named_region_test.py | 25 ++--- tests/schedule_tree/nesting_test.py | 7 +- tests/schedule_tree/schedule_test.py | 2 + 18 files changed, 151 insertions(+), 82 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 6657d09808..fdba40526d 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -62,8 +62,8 @@ import sympy as sp from dace import dtypes from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, - ReturnBlock, SDFGState) +from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, + LoopRegion, ReturnBlock, SDFGState) from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.graph import Edge from dace.properties import CodeBlock diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 9357ca3db9..84f36189b3 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -652,7 +652,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) ############################# # Create initial tree from CFG - cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') + cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '') # Traverse said tree (also into states) to create the schedule tree def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]: diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 9b6086098e..83c5e5c148 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -11,6 +11,7 @@ from dace import dtypes, properties, symbolic from dace.codegen import cppunparse from dace.frontend.python.astutils import ASTFindReplace +from dace.sdfg.state import ConditionalBlock, LoopRegion if TYPE_CHECKING: from dace.sdfg.state import StateSubgraphView @@ -200,3 +201,14 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]): for edge in block.edges(): if edge.data.data in repl: edge.data.data = repl[edge.data.data] + + # Replace in loop or branch conditions: + if isinstance(cf, LoopRegion): + replace_in_codeblock(cf.loop_condition, repl) + if cf.update_statement: + replace_in_codeblock(cf.update_statement, repl) + if cf.init_statement: + replace_in_codeblock(cf.init_statement, repl) + elif isinstance(cf, ConditionalBlock): + for c, _ in cf.branches: + replace_in_codeblock(c, repl) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 083513005f..ca733258df 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -13,7 +13,6 @@ import dace from dace.frontend.python import astutils -from dace.sdfg.replace import replace_in_codeblock import dace.serialize from dace import data as dt from dace import dtypes @@ -3319,6 +3318,9 @@ def replace_dict(self, symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, replace_in_graph: bool = True, replace_keys: bool = True): + # Avoid circular imports + from dace.sdfg.replace import replace_in_codeblock + if replace_keys: from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index dc84bd4478..c6a701bd48 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -4,10 +4,10 @@ import itertools from networkx import MultiDiGraph -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.subsets import Range, Subset, union import dace.subsets as subsets -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set, Union +from typing import Dict, Iterable, List, Optional, Tuple, Set, Union from dace import data, dtypes, symbolic from dace.codegen import control_flow as cf @@ -30,10 +30,13 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS """ # Nest states - states = subgraph.nodes() + blocks: List[ControlFlowBlock] = subgraph.nodes() return_state = None - if len(states) > 1: + if len(blocks) > 1: + # Avoid cyclic imports + from dace.transformation.passes.analysis import loop_analysis + graph: ControlFlowRegion = blocks[0].parent_graph if start is not None: source_node = start else: @@ -48,6 +51,22 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS raise NotImplementedError sink_node = sink_nodes[0] + all_blocks: List[ControlFlowBlock] = [] + is_edges: List[Edge[InterstateEdge]] = [] + for b in blocks: + if isinstance(b, AbstractControlFlowRegion): + for nb in b.all_control_flow_blocks(): + all_blocks.append(nb) + for e in b.all_interstate_edges(): + is_edges.append(e) + else: + all_blocks.append(b) + states: List[SDFGState] = [b for b in all_blocks if isinstance(b, SDFGState)] + for src in blocks: + for dst in blocks: + for edge in graph.edges_between(src, dst): + is_edges.append(edge) + # Find read/write sets read_set, write_set = set(), set() for state in states: @@ -67,12 +86,10 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS if e.data.data and e.data.data in sdfg.arrays: write_set.add(e.data.data) # Add data from edges - for src in states: - for dst in states: - for edge in sdfg.edges_between(src, dst): - for s in edge.data.free_symbols: - if s in sdfg.arrays: - read_set.add(s) + for edge in is_edges: + for s in edge.data.free_symbols: + if s in sdfg.arrays: + read_set.add(s) # Find NestedSDFG's unique data rw_set = read_set | write_set @@ -82,7 +99,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS continue found = False for state in sdfg.states(): - if state in states: + if state in blocks: continue for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.data == name): @@ -98,7 +115,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Find defined subgraph symbols defined_symbols = set() strictly_defined_symbols = set() - for e in subgraph.edges(): + for e in is_edges: defined_symbols.update(set(e.data.assignments.keys())) for k, v in e.data.assignments.items(): try: @@ -107,22 +124,30 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS except AttributeError: # `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args` pass - - return_state = new_state = sdfg.add_state('nested_sdfg_parent') + for b in all_blocks: + if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '': + defined_symbols.update(b.loop_variable) + if b.loop_variable not in sdfg.symbols: + if b.init_statement: + init_assignment = loop_analysis.get_init_assignment(b) + if b.loop_variable not in {str(s) for s in symbolic.pystr_to_symbolic(init_assignment).args}: + strictly_defined_symbols.add(b.loop_variable) + else: + strictly_defined_symbols.add(b.loop_variable) + + return_state = new_state = graph.add_state('nested_sdfg_parent') nsdfg = SDFG("nested_sdfg", constants=sdfg.constants_prop, parent=new_state) nsdfg.add_node(source_node, is_start_state=True) - nsdfg.add_nodes_from([s for s in states if s is not source_node]) - for s in states: - s.parent = nsdfg + nsdfg.add_nodes_from([s for s in blocks if s is not source_node]) for e in subgraph.edges(): nsdfg.add_edge(e.src, e.dst, e.data) - for e in sdfg.in_edges(source_node): - sdfg.add_edge(e.src, new_state, e.data) - for e in sdfg.out_edges(sink_node): - sdfg.add_edge(new_state, e.dst, e.data) + for e in graph.in_edges(source_node): + graph.add_edge(e.src, new_state, e.data) + for e in graph.out_edges(sink_node): + graph.add_edge(new_state, e.dst, e.data) - sdfg.remove_nodes_from(states) + graph.remove_nodes_from(blocks) # Add NestedSDFG arrays for name in read_set | write_set: @@ -177,15 +202,15 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Part (2) if out_state is not None: - extra_state = sdfg.add_state('symbolic_output') - for e in sdfg.out_edges(new_state): - sdfg.add_edge(extra_state, e.dst, e.data) - sdfg.remove_edge(e) - sdfg.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping)) + extra_state = graph.add_state('symbolic_output') + for e in graph.out_edges(new_state): + graph.add_edge(extra_state, e.dst, e.data) + graph.remove_edge(e) + graph.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping)) new_state = extra_state else: - return_state = states[0] + return_state = blocks[0] return return_state @@ -244,7 +269,8 @@ def _copy_state(sdfg: SDFG, return state_copy -def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock, Set[ControlFlowBlock]]: +def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock, + Tuple[Set[ControlFlowBlock], ControlFlowBlock]]: """ Partitions a CFG to subgraphs that can be nested independently of each other. The method does not nest the subgraphs but alters the graph; (1) interstate edges are split, (2) scope source/sink nodes that belong to multiple @@ -352,16 +378,10 @@ def nest_sdfg_control_flow(sdfg: SDFG, components=None): :param sdfg: The SDFG to be partitioned. :param components: An existing partition of the SDFG. """ - - components = components or find_sdfg_control_flow(sdfg) - - num_components = len(components) - - if num_components < 2: - return - - for i, (start, (component, _)) in enumerate(components.items()): - nest_sdfg_subgraph(sdfg, graph.SubgraphView(sdfg, component), start) + regions = list(sdfg.all_control_flow_regions()) + for region in regions: + nest_sdfg_subgraph(region.sdfg, SubgraphView(region.sdfg, [region]), region) + sdfg.reset_cfg_list() def nest_state_subgraph(sdfg: SDFG, diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index d8bd8745ff..e558ab0b20 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -274,6 +274,10 @@ class ControlFlowRegionPass(Pass): CATEGORY: str = 'Helper' + apply_to_conditionals = properties.Property(dtype=bool, default=False, + desc='Whether or not to apply to conditional blocks. If false, do ' + + 'not apply to conditional blocks, but only their children.') + def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[int, Optional[Any]]]: """ Applies the pass to control flow regions of the given SDFG by calling ``apply`` on each region. @@ -287,7 +291,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D """ result = {} for region in sdfg.all_control_flow_regions(recursive=True, parent_first=False): - if isinstance(region, ConditionalBlock): + if isinstance(region, ConditionalBlock) and not self.apply_to_conditionals: continue retval = self.apply(region, pipeline_results) if retval is not None: diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index bc1bac4640..720b0b8b5b 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -153,7 +153,12 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ # The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) for n, v in reachable_nodes(cfg.nx): - single_level_reachable[cfg.cfg_id][n] = set(v) + reach = set() + for nd in v: + reach.add(nd) + if isinstance(nd, AbstractControlFlowRegion): + reach.update(nd.all_control_flow_blocks()) + single_level_reachable[cfg.cfg_id][n] = reach if isinstance(cfg, LoopRegion): single_level_reachable[cfg.cfg_id][n].update(cfg.nodes()) @@ -166,7 +171,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) for block in cfg.nodes(): for reached in single_level_reachable[block.parent_graph.cfg_id][block]: - if isinstance(reached, ControlFlowRegion): + if isinstance(reached, AbstractControlFlowRegion): result[block].update(reached.all_control_flow_blocks()) result[block].add(reached) if block.parent_graph is not sdfg: @@ -516,7 +521,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.States def depends_on(self): - return {AccessSets, FindAccessNodes, StateReachability} + return {AccessSets, FindAccessNodes, ControlFlowBlockReachability} def _find_dominating_write(self, desc: str, @@ -615,7 +620,9 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[ FindAccessNodes.__name__][sdfg.cfg_id] - state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id] + block_reach: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[ + ControlFlowBlockReachability.__name__ + ] anames = sdfg.arrays.keys() for desc in sdfg.arrays: @@ -657,7 +664,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i continue write_state, write_node = write dominators = all_doms_transitive[write_state] - reach = state_reach[write_state] + reach = block_reach[write_state.parent_graph.cfg_id][write_state] for other_write, other_accesses in result[desc].items(): if other_write is not None and other_write[1] is write_node and other_write[0] is write_state: continue diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index cda193f43a..23f2a785f5 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -77,7 +77,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters cfg.remove_node(node) else: result.add(node) - cfg.remove_node(block) + cfg.remove_node(node) if not annotated: return result or None diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index fa1a3c6f97..b852b798b1 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -97,4 +97,5 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: lifted_branches += self._lift_conditionals(sdfg) if lifted_branches == 0 and lifted_loops == 0: return None + top_sdfg.reset_cfg_list() return lifted_loops, lifted_branches diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py index 111944614c..29a400d5d1 100644 --- a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -14,6 +14,10 @@ class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass): CATEGORY: str = 'Simplification' + def __init__(self): + super().__init__() + self.apply_to_conditionals = True + def modifies(self) -> ppl.Modifies: return ppl.Modifies.CFG @@ -58,5 +62,8 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]: region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge()) region.parent_graph.remove_node(region) - return removed_branches if removed_branches > 0 else None + if removed_branches > 0: + region.reset_cfg_list() + return removed_branches + return None diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index d3e8b580da..97eb383764 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -73,6 +73,8 @@ def __init__(self, validate_all: bool = False, skip: Optional[Set[str]] = None, verbose: bool = False, + no_inline_function_call_regions: bool = False, + no_inline_named_regions: bool = False, pass_options: Optional[Dict[str, Any]] = None): if skip: passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip] @@ -88,6 +90,9 @@ def __init__(self, else: self.verbose = verbose + self.no_inline_function_call_regions = no_inline_function_call_regions + self.no_inline_named_regions = no_inline_named_regions + pass_opts = { 'no_inline_function_call_regions': self.no_inline_function_call_regions, 'no_inline_named_regions': self.no_inline_named_regions, diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index a8ece680a6..3283d2e37f 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -17,6 +17,7 @@ import dace.frontend.fortran.ast_transforms as ast_transforms import dace.frontend.fortran.ast_utils as ast_utils import dace.frontend.fortran.ast_internal_classes as ast_internal_classes +from dace.sdfg.state import LoopRegion def test_fortran_frontend_array_access(): @@ -199,9 +200,10 @@ def test_fortran_frontend_memlet_in_map_test(): """ sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test") sdfg.simplify() - # Expect that start is begin of for loop -> only one out edge to guard defining iterator variable - assert len(sdfg.out_edges(sdfg.start_state)) == 1 - iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0]) + # Expect that the start block is a loop + loop = sdfg.nodes()[0] + assert isinstance(loop, LoopRegion) + iter_var = symbolic.pystr_to_symbolic(loop.loop_variable) for state in sdfg.states(): if len(state.nodes()) > 1: diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index 14d380c463..bf40ff4409 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -265,12 +265,12 @@ def dce_tester(a: dace.float64[20], b: dace.float64[20]): sdfg = dce_tester.to_sdfg(simplify=False) result = Pipeline([DeadDataflowElimination(), DeadStateElimination()]).apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.number_of_nodes() <= 6 + assert sdfg.number_of_nodes() <= 4 # Check that arrays were removed assert all('c' not in [n.data for n in state.data_nodes()] for state in sdfg.nodes()) assert any('f' in [n.data for n in rstate if isinstance(n, dace.nodes.AccessNode)] - for rstate in result['DeadDataflowElimination'].values()) + for rstate in result[DeadDataflowElimination.__name__][0].values()) def test_dce_callback(): diff --git a/tests/passes/simplification/prune_empty_conditional_branches_test.py b/tests/passes/simplification/prune_empty_conditional_branches_test.py index 65463ad3a7..dc25cdc670 100644 --- a/tests/passes/simplification/prune_empty_conditional_branches_test.py +++ b/tests/passes/simplification/prune_empty_conditional_branches_test.py @@ -36,7 +36,7 @@ def prune_empty_else(A: dace.int32[N]): res = PruneEmptyConditionalBranches().apply_pass(sdfg, {}) - assert res[conditional] == 1 + assert res[conditional.cfg_id] == 1 assert len(conditional.branches) == 1 N1 = 32 @@ -82,7 +82,7 @@ def prune_empty_if_with_else(A: dace.int32[N]): res = PruneEmptyConditionalBranches().apply_pass(sdfg, {}) - assert res[conditional] == 1 + assert res[conditional.cfg_id] == 1 assert len(conditional.branches) == 1 assert conditional.branches[0][0] is not None diff --git a/tests/python_frontend/function_regions_test.py b/tests/python_frontend/function_regions_test.py index c5c9b4ac6f..5d5082a92e 100644 --- a/tests/python_frontend/function_regions_test.py +++ b/tests/python_frontend/function_regions_test.py @@ -3,6 +3,7 @@ import numpy as np import dace from dace.sdfg.state import FunctionCallRegion +from dace.transformation.passes.simplify import SimplifyPass def test_function_call(): N = dace.symbol("N") @@ -11,9 +12,9 @@ def func(A: dace.float64[N]): @dace.program def prog(I: dace.float64[N]): return func(I) - prog.use_experimental_cfg_blocks = True - sdfg = prog.to_sdfg() - call_region: FunctionCallRegion = sdfg.nodes()[1] + sdfg = prog.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + call_region: FunctionCallRegion = sdfg.nodes()[0] assert call_region.arguments == {'A': 'I'} assert sdfg(np.array([+1], dtype=np.float64), N=1) == 15 assert sdfg(np.array([-1], dtype=np.float64), N=1) == 5 @@ -26,13 +27,13 @@ def func(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]): def prog(E: dace.float64[N], F: dace.float64[N], G: dace.float64[N]): func(A=E, B=F, C=G) func(A=G, B=E, C=E) - prog.use_experimental_cfg_blocks = True E = np.array([1]) F = np.array([2]) G = np.array([3]) - sdfg = prog.to_sdfg(E=E, F=F, G=G, N=1) - call1: FunctionCallRegion = sdfg.nodes()[1] - call2: FunctionCallRegion = sdfg.nodes()[2] + sdfg = prog.to_sdfg(E=E, F=F, G=G, N=1, simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + call1: FunctionCallRegion = sdfg.nodes()[0] + call2: FunctionCallRegion = sdfg.nodes()[1] assert call1.arguments == {'A': 'E', 'B': 'F', 'C': 'G'} assert call2.arguments == {'A': 'G', 'B': 'E', 'C': 'E'} @@ -44,10 +45,10 @@ def func(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]): def prog(): func(A=np.array([1]), B=np.array([2]), C=np.array([3])) func(A=np.array([3]), B=np.array([1]), C=np.array([1])) - prog.use_experimental_cfg_blocks = True - sdfg = prog.to_sdfg(N=1) - call1: FunctionCallRegion = sdfg.nodes()[1] - call2: FunctionCallRegion = sdfg.nodes()[2] + sdfg = prog.to_sdfg(N=1, simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + call1: FunctionCallRegion = sdfg.nodes()[0] + call2: FunctionCallRegion = sdfg.nodes()[1] assert call1.arguments == {'A': '__tmp0', 'B': '__tmp1', 'C': '__tmp2'} assert call2.arguments == {'A': '__tmp4', 'B': '__tmp5', 'C': '__tmp6'} diff --git a/tests/python_frontend/named_region_test.py b/tests/python_frontend/named_region_test.py index f9be206bca..593fde5c0f 100644 --- a/tests/python_frontend/named_region_test.py +++ b/tests/python_frontend/named_region_test.py @@ -3,6 +3,7 @@ import numpy as np import dace from dace.sdfg.state import NamedRegion +from dace.transformation.passes.simplify import SimplifyPass def test_named_region_no_name(): @@ -11,21 +12,21 @@ def func(A: dace.float64[1]): with dace.named: A[0] = 20 return A - func.use_experimental_cfg_blocks = True - sdfg = func.to_sdfg() - named_region = sdfg.reset_cfg_list()[1] + sdfg = func.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + named_region = sdfg.nodes()[0] assert isinstance(named_region, NamedRegion) A = np.zeros(shape=(1,)) - assert func(A) == 20 + assert sdfg(A) == 20 def test_named_region_with_name(): @dace.program def func(): with dace.named("my named region"): pass - func.use_experimental_cfg_blocks = True - sdfg = func.to_sdfg() - named_region: NamedRegion = sdfg.reset_cfg_list()[1] + sdfg = func.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + named_region: NamedRegion = sdfg.nodes()[0] assert named_region.label == "my named region" def test_nested_named_regions(): @@ -35,13 +36,13 @@ def func(): with dace.named("middle region"): with dace.named("inner region"): pass - func.use_experimental_cfg_blocks = True - sdfg = func.to_sdfg() - outer: NamedRegion = sdfg.nodes()[1] + sdfg = func.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + outer: NamedRegion = sdfg.nodes()[0] assert outer.label == "outer region" - middle: NamedRegion = outer.nodes()[1] + middle: NamedRegion = outer.nodes()[0] assert middle.label == "middle region" - inner: NamedRegion = middle.nodes()[1] + inner: NamedRegion = middle.nodes()[0] assert inner.label == "inner region" if __name__ == "__main__": diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index 161f15d6c1..8361ecb149 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -5,6 +5,7 @@ import dace from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.sdfg.utils import inline_control_flow_regions from dace.transformation.dataflow import RemoveSliceView import pytest @@ -63,7 +64,8 @@ def tester(A: dace.float64[N, N]): if simplified: assert [type(n) - for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode] + for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.GeneralLoopScope, + tn.TaskletNode] tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] @@ -127,6 +129,7 @@ def tester(a: dace.float64[40], b: dace.float64[40]): nester(b[1:21], a[10:30]) sdfg = tester.to_sdfg(simplify=False) + inline_control_flow_regions(sdfg) sdfg.apply_transformations_repeated(RemoveSliceView) stree = as_schedule_tree(sdfg) @@ -150,6 +153,7 @@ def tester(a: dace.float64[40]): nester(a[1:21], a[10:30]) sdfg = tester.to_sdfg(simplify=False) + inline_control_flow_regions(sdfg) sdfg.apply_transformations_repeated(RemoveSliceView) stree = as_schedule_tree(sdfg) @@ -176,6 +180,7 @@ def tester(a: dace.float64[N, N]): nester1(a[:, 1]) sdfg = tester.to_sdfg(simplify=simplify) + inline_control_flow_regions(sdfg) stree = as_schedule_tree(sdfg) # Simplifying yields a different SDFG due to views, so testing is slightly different diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 1bf2962cb3..542ff425dc 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -5,6 +5,8 @@ from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree import numpy as np +from dace.sdfg.utils import inline_control_flow_regions + def test_for_in_map_in_for(): From a047d372b2dea95dc7a3d84d9a86ae188953656b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 18 Oct 2024 15:06:00 +0200 Subject: [PATCH 060/108] And yet more --- dace/sdfg/replace.py | 3 +- .../transformation/passes/scalar_to_symbol.py | 29 ++++++++++++ tests/passes/scalar_to_symbol_test.py | 45 +++++++------------ tests/sdfg/work_depth_test.py | 6 +-- 4 files changed, 51 insertions(+), 32 deletions(-) diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 83c5e5c148..b49f13cee6 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -211,4 +211,5 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]): replace_in_codeblock(cf.init_statement, repl) elif isinstance(cf, ConditionalBlock): for c, _ in cf.branches: - replace_in_codeblock(c, repl) + if c is not None: + replace_in_codeblock(c, repl) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index a0cb08ea0c..afaa4319da 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -21,6 +21,7 @@ from dace.sdfg import utils as sdutils from dace.sdfg.replace import replace_properties_dict from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.transformation import helpers as xfh from dace.transformation import pass_pipeline as passes from dace.transformation.transformation import experimental_cfg_block_compatible @@ -228,6 +229,19 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer interstate_symbols = set() for edge in sdfg.all_interstate_edges(): interstate_symbols |= edge.data.free_symbols + for reg in sdfg.all_control_flow_regions(): + if isinstance(reg, LoopRegion): + interstate_symbols |= reg.loop_condition.get_free_symbols() + if reg.loop_variable: + interstate_symbols.add(reg.loop_variable) + if reg.update_statement: + interstate_symbols |= reg.update_statement.get_free_symbols() + if reg.init_statement: + interstate_symbols |= reg.init_statement.get_free_symbols() + elif isinstance(reg, ConditionalBlock): + for c, _ in reg.branches: + if c is not None: + interstate_symbols |= c.get_free_symbols() for candidate in (candidates - interstate_symbols): if integers_only and sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES: candidates.remove(candidate) @@ -722,6 +736,21 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # should work for all Python versions. assignment = cleanup_re[scalar].sub(scalar, assignment.strip()) ise.assignments[aname] = assignment + for reg in sdfg.all_control_flow_regions(): + if isinstance(reg, LoopRegion): + codes = [reg.loop_condition] + if reg.init_statement: + codes.append(reg.init_statement) + if reg.update_statement: + codes.append(reg.update_statement) + for cd in codes: + for stmt in cd.code: + promo.visit(stmt) + elif isinstance(reg, ConditionalBlock): + for c, _ in reg.branches: + if c is not None: + for stmt in c.code: + promo.visit(stmt) # Step 7: Indirection remove_symbol_indirection(sdfg) diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 7fdfbdf737..499e5cc0bd 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar to symbol promotion functionality. """ import dace +from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.transformation.passes import scalar_to_symbol from dace.transformation import transformation as xf, interstate as isxf from dace.transformation.interstate import loop_detection as ld @@ -188,15 +189,21 @@ def testprog6(A: dace.float64[20, 20]): sdfg: dace.SDFG = testprog6.to_sdfg(simplify=False) assert scalar_to_symbol.find_promotable_scalars(sdfg) == {'j'} scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) - sdfg.apply_transformations_repeated(isxf.StateFusion) + sdfg.apply_transformations_repeated([isxf.StateFusion, isxf.BlockFusion]) - # There should be 4 states: - # [empty] --j=A[1, 1]--> [A->MapEntry->Tasklet->MapExit->A] --> [empty] - # \--------------------------------------------/ - assert sdfg.number_of_nodes() == 4 - ctr = collections.Counter(s.number_of_nodes() for s in sdfg) - assert ctr[0] == 3 - assert ctr[5] == 1 + # There should be 2 states: + # [empty] --j=A[1, 1]--> [Conditional] + assert sdfg.number_of_nodes() == 2 + # The conditional should contain one branch, with one state, with a single map from A->A inside of it. + cond = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + cond = n + break + assert cond is not None + assert len(cond.branches) == 1 + assert len(cond.branches[0][1].nodes()) == 1 + assert len(cond.branches[0][1].nodes()[0].nodes()) == 5 # Program should produce correct result A = np.random.rand(20, 20) @@ -235,24 +242,6 @@ def testprog7(A: dace.float64[20, 20]): assert np.allclose(A, expected) -class LoopTester(ld.DetectLoop, xf.MultiStateTransformation): - """ Tester method that sets loop index on a guard state. """ - - def can_be_applied(self, graph, expr_index, sdfg, permissive): - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False - guard = self.loop_guard - if hasattr(guard, '_LOOPINDEX'): - return False - return True - - def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): - guard = self.loop_guard - edge = sdfg.in_edges(guard)[0] - loopindex = next(iter(edge.data.assignments.keys())) - guard._LOOPINDEX = loopindex - - def test_promote_loop(): """ Loop promotion. """ N = dace.symbol('N') @@ -269,7 +258,7 @@ def testprog8(A: dace.float32[20, 20]): assert 'i' in scalar_to_symbol.find_promotable_scalars(sdfg) scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.apply_transformations_repeated(LoopTester) == 1 + assert any(isinstance(n, LoopRegion) for n in sdfg.nodes()) def test_promote_loops(): @@ -294,7 +283,7 @@ def testprog9(A: dace.float32[20, 20]): assert 'k' in scalars scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.apply_transformations_repeated(LoopTester) == 3 + assert any(isinstance(n, LoopRegion) for n in sdfg.nodes()) def test_promote_indirection(): diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 11873fa03d..dd1c3eb518 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -194,11 +194,11 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_4') * N, sp.Symbol('num_execs_0_4'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_4') * N, 2 * sp.Symbol('num_execs_0_4'))), + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), 'break_for_loop': (break_for_loop, (N**2, N)), - 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_7') * N, sp.Symbol('num_execs_0_7'))), 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), 'reduction_library_node': (reduction_library_node, (456, sp.log(456))), 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(N))), From 39db9091d5ee9ebed4c1f1a6051b8163f1f37f7a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 18 Oct 2024 17:33:55 +0200 Subject: [PATCH 061/108] Another one --- dace/codegen/targets/cpu.py | 2 +- dace/frontend/python/newast.py | 4 ++-- dace/transformation/auto/auto_optimize.py | 2 -- dace/transformation/interstate/loop_to_map.py | 7 ++++++- dace/transformation/transformation.py | 3 +-- tests/passes/constant_propagation_test.py | 10 ++++++++++ 6 files changed, 20 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index fad672ffc1..5155c011ab 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -1706,7 +1706,7 @@ def _generate_NestedSDFG( # If the SDFG has a unique name, use it sdfg_label = node.unique_name else: - sdfg_label = "%s_%d_%d_%d" % (node.sdfg.name, sdfg.cfg_id, state_id, dfg.node_id(node)) + sdfg_label = "%s_%d_%d_%d" % (node.sdfg.name, cfg.cfg_id, state_id, dfg.node_id(node)) code_already_generated = False if unique_functions and not inline: diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index cacf15d785..9387c11a7d 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4502,7 +4502,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): else: name = "call" call_region = FunctionCallRegion(label=f"{name}_{node.lineno}", arguments=[]) - self.cfg_target.add_node(call_region) + self.cfg_target.add_node(call_region, ensure_unique_name=True) self._on_block_added(call_region) previous_last_cfg_target = self.last_cfg_target previous_target = self.cfg_target @@ -4749,7 +4749,7 @@ def visit_With(self, node: ast.With, is_async=False): else: named_region_name = f"Named Region {node.lineno}" named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo) - self.cfg_target.add_node(named_region) + self.cfg_target.add_node(named_region, ensure_unique_name=True) self._on_block_added(named_region) self._recursive_visit(node.body, "init_named", node.lineno, named_region, unconnected_last_block=False) return diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 7bced3bec9..1c77160338 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -573,8 +573,6 @@ def auto_optimize(sdfg: SDFG, sdfg.apply_transformations_repeated(TrivialMapElimination, validate=validate, validate_all=validate_all) while transformed: sdfg.simplify(validate=False, validate_all=validate_all) - for s in sdfg.cfg_list: - xfh.split_interstate_edges(s) l2ms = sdfg.apply_transformations_repeated((LoopToMap, RefineNestedAccess), validate=False, validate_all=validate_all) diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 5a50c54c45..8d74248270 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -415,7 +415,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): del sdfg.arrays[name] # Add NestedSDFG node - cnode = body.add_nested_sdfg(nsdfg, None, read_set, write_set) + cnode = body.add_nested_sdfg(nsdfg, body, read_set, write_set) if sdfg.parent: for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items(): if s not in cnode.symbol_mapping: @@ -567,3 +567,8 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): sdfg.remove_symbol(itervar) sdfg.reset_cfg_list() + for n, p in sdfg.all_nodes_recursive(): + if isinstance(n, nodes.NestedSDFG): + n.sdfg.parent = p + n.sdfg.parent_nsdfg_node = n + n.sdfg.parent_sdfg = p.sdfg diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 0cec526a2c..9b109bfcfb 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -811,8 +811,7 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], cfg_id: int = self.subgraph = set(subgraph.graph.node_id(n) for n in subgraph.nodes()) if isinstance(subgraph.graph, SDFGState): - sdfg = subgraph.graph.parent - self.cfg_id = sdfg.cfg_id + self.cfg_id = subgraph.graph.parent_graph.cfg_id self.state_id = subgraph.graph.block_id elif isinstance(subgraph.graph, SDFG): self.cfg_id = subgraph.graph.cfg_id diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index 909e22a2b5..643f397d20 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -40,6 +40,8 @@ def program(A: dace.int64[20]): A[l] = k sdfg = program.to_sdfg() + ScalarToSymbolPromotion().apply_pass(sdfg, {}) + ConstantPropagation().apply_pass(sdfg, {}) assert set(sdfg.symbols.keys()) == {'i'} @@ -63,6 +65,8 @@ def program(a: dace.float64[20]): a[0] = i # Use i - should be const sdfg = program.to_sdfg() + ScalarToSymbolPromotion().apply_pass(sdfg, {}) + ConstantPropagation().apply_pass(sdfg, {}) for node in sdfg.all_control_flow_regions(): if isinstance(node, LoopRegion): @@ -85,6 +89,8 @@ def program(a: dace.float64[20]): a[i] = i # Use i - not const sdfg = program.to_sdfg() + ScalarToSymbolPromotion().apply_pass(sdfg, {}) + ConstantPropagation().apply_pass(sdfg, {}) for node in sdfg.all_control_flow_regions(): if isinstance(node, LoopRegion): @@ -112,6 +118,8 @@ def program(a: dace.float64[20, 20]): a[j, k] = 1 sdfg = program.to_sdfg() + ScalarToSymbolPromotion().apply_pass(sdfg, {}) + ConstantPropagation().apply_pass(sdfg, {}) assert 'j' in sdfg.symbols for node in sdfg.all_control_flow_regions(): @@ -141,6 +149,8 @@ def program(a: dace.float64[20, 20], scal: dace.int32): a[i, j] = 3 sdfg = program.to_sdfg() + ScalarToSymbolPromotion().apply_pass(sdfg, {}) + ConstantPropagation().apply_pass(sdfg, {}) assert len(sdfg.symbols.keys()) == 1 From b7fe24219455e980b3c9f1ca6de06291f9dc986a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 22 Oct 2024 14:43:50 +0200 Subject: [PATCH 062/108] Added a test to highlights the error. Currently we only see the error in `tests/numpy/ufunc_support_test.py::test_ufunc_add_accumulate_simple` if the auto optimizations are enabled. --- .../refine_nested_access_test.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index d9fb9a7392..94c02a88de 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -156,7 +156,108 @@ def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int): assert np.allclose(ref, val) +def _make_rna_read_and_write_set_sdfg(diff_in_out: bool) -> dace.SDFG: + """Generates the SDFG for the `test_rna_read_and_write_sets_*()` tests. + + If `diff_in_out` is `False` then the output is also used as temporary storage + within the nested SDFG. Because of the definition of the read/write sets, + this usage of the temporary storage is not picked up and it is only considered + as write set. + + If `diff_in_out` is true, then a different storage container, which is classified + as output, is used as temporary storage. + + This test was added during [PR#1678](https://github.com/spcl/dace/pull/1678). + """ + + def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: + sdfg = dace.SDFG("inner_sdfg") + state = sdfg.add_state(is_start_block=True) + sdfg.add_array("A", dtype=dace.float64, shape=(2,), transient=False) + sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False) + + A = state.add_access("A") + T1_input = state.add_access("T1") + if diff_in_out: + sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False) + T1_output = state.add_access("T2") + else: + T1_output = state.add_access("T1") + + tsklt = state.add_tasklet( + "comp", + inputs={"__in1": None, "__in2": None}, + outputs={"__out": None}, + code="__out = __in1 + __in2", + ) + + state.add_edge(A, None, tsklt, "__in1", dace.Memlet("A[1]")) + # An alternative would be to write to a different location here. + # Then, the data would be added to the access node. + state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]")) + state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet("T1[0]")) + state.add_edge(tsklt, "__out", T1_output, None, dace.Memlet(T1_output.data + "[1]")) + return sdfg + + sdfg = dace.SDFG("Parent_SDFG") + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array("A", dtype=dace.float64, shape=(2,), transient=False) + sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False) + sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False) + A = state.add_access("A") + T1 = state.add_access("T1") + + nested_sdfg = _make_nested_sdfg(diff_in_out) + + nsdfg = state.add_nested_sdfg( + nested_sdfg, + parent=sdfg, + inputs={"A"}, + outputs={"T2", "T1"} if diff_in_out else {"T1"}, + symbol_mapping={}, + ) + + state.add_edge(A, None, nsdfg, "A", dace.Memlet("A[0:2]")) + state.add_edge(nsdfg, "T1", T1, None, dace.Memlet("T1[0:2]")) + + if diff_in_out: + state.add_edge(nsdfg, "T2", state.add_access("T2"), None, dace.Memlet("T2[0:2]")) + sdfg.validate() + return sdfg + + +def test_rna_read_and_write_sets_doule_use(): + """ + NOTE: Under the current definition of the read/write sets this test will fail. + """ + + # The output is used also as temporary storage. + sdfg = _make_rna_read_and_write_set_sdfg(False) + nb_applied = sdfg.apply_transformations_repeated( + [RefineNestedAccess], + validate=True, + validate_all=True, + ) + assert nb_applied > 0 + + +def test_rna_read_and_write_sets_different_storage(): + + # There is a dedicated temporary storage used. + sdfg = _make_rna_read_and_write_set_sdfg(True) + + nb_applied = sdfg.apply_transformations_repeated( + [RefineNestedAccess], + validate=True, + validate_all=True, + ) + assert nb_applied > 0 + + if __name__ == '__main__': test_refine_dataflow() test_refine_interstate() test_free_symbols_only_by_indices() + test_rna_read_and_write_sets_different_storage() + test_rna_read_and_write_sets_doule_use() From b546b0742104464de08a3ee2840c073be2a52605 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 22 Oct 2024 14:50:55 +0200 Subject: [PATCH 063/108] I now removed the filtering inside the read and write set. This is _just_ a proposal. --- dace/sdfg/state.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 53f8d98491..09e7607d65 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -761,13 +761,15 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) + # NOTE: In a previous version a _single_ read (i.e. leaving Memlet) that was + # fully covered by a single write (i.e. an incoming Memlet) was removed from + # the read set and only the write survived. However, this was never fully + # implemented nor correctly implemented and caused problems. + # So this filtering was removed. + for subgraph in utils.concurrent_subgraphs(self): subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph. subgraph_write_set = collections.defaultdict(list) - # Traverse in topological order, so data that is written before it - # is read is not counted in the read set - # NOTE: Each AccessNode is processed individually. Thus, if an array appears multiple - # times in a path, the individual results are combined, without further processing. for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()): if not isinstance(n, nd.AccessNode): # Read and writes can only be done through access nodes, @@ -805,18 +807,6 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, else out_edge.data.src_subset ) - # If a memlet reads a particular region of data from the access node and there - # exists a memlet at the same access node that writes to the same region, then - # this read is ignored, and not included in the final read set, but only - # accounted fro in the write set. See also note below. - # TODO: Handle the case when multiple disjoint writes would be needed to cover the - # read. E.g. edges write `0:10` and `10:20` but the read happens at `5:15`. - for out_edge in list(out_edges): - for in_edge in in_edges: - if in_subsets[in_edge].covers(out_subsets[out_edge]): - out_edges.remove(out_edge) - break - # Update the read and write sets of the subgraph. if in_edges: subgraph_write_set[n.data].extend(in_subsets.values()) From 8f9e72f568af6e5829d13a6eedc740714c0e0acf Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 22 Oct 2024 13:39:47 +0200 Subject: [PATCH 064/108] And more --- dace/codegen/targets/framecode.py | 15 +- dace/frontend/python/newast.py | 2 +- .../analysis/schedule_tree/sdfg_to_tree.py | 5 +- dace/sdfg/utils.py | 122 ++++++++- .../dataflow/double_buffering.py | 7 +- dace/transformation/dataflow/map_fission.py | 21 +- dace/transformation/helpers.py | 242 +++++------------- .../passes/analysis/analysis.py | 39 --- tests/passes/constant_propagation_test.py | 2 + .../multiple_nested_sdfgs_test.py | 4 +- tests/schedule_tree/schedule_test.py | 12 +- tests/sdfg/state_test.py | 8 +- tests/transformations/nest_subgraph_test.py | 8 +- 13 files changed, 236 insertions(+), 251 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 62cfd03f23..1b58008115 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -21,7 +21,7 @@ from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.transformation.passes.analysis import StateReachability, loop_analysis @@ -1047,21 +1047,20 @@ def generate_code(self, return (generated_header, clean_code, self._dispatcher.used_targets, self._dispatcher.used_environments) -def _get_dominator_and_postdominator(cfg: ControlFlowRegion, accesses: List[Tuple[SDFGState, nodes.AccessNode]]): +def _get_dominator_and_postdominator(sdfg: SDFG, accesses: List[Tuple[SDFGState, nodes.AccessNode]]): """ Gets the closest common dominator and post-dominator for a list of states. Used for determining allocation of data used in branched states. """ - # Get immediate dominators - idom = nx.immediate_dominators(cfg.nx, cfg.start_block) - alldoms = cfg_analysis.all_dominators(cfg, idom) + alldoms: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = collections.defaultdict(lambda: set()) + allpostdoms: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = collections.defaultdict(lambda: set()) + idom: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} + ipostdom: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} + utils.get_control_flow_block_dominators(sdfg, idom, alldoms, ipostdom, allpostdoms) states = [a for a, _ in accesses] data_name = accesses[0][1].data - # Get immediate post-dominators - ipostdom, allpostdoms = utils.postdominators(cfg, return_alldoms=True) - # All dominators and postdominators include the states themselves for state in states: alldoms[state].add(state) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 9387c11a7d..1542365e73 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2561,7 +2561,7 @@ def visit_If(self, node: ast.If): # Add conditional region cond_block = ConditionalBlock(f'if_{node.lineno}') - self.cfg_target.add_node(cond_block) + self.cfg_target.add_node(cond_block, ensure_unique_name=True) self._on_block_added(cond_block) if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 84f36189b3..46ad04f70b 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -652,7 +652,10 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) ############################# # Create initial tree from CFG - cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '') + if sdfg.using_experimental_blocks: + cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '') + else: + cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') # Traverse said tree (also into states) to create the schedule tree def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]: diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index f2929e2ac6..e92929c36f 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -11,10 +11,10 @@ import dace.sdfg.nodes from dace.codegen import compiled_sdfg as csdfg from dace.sdfg.graph import MultiConnectorEdge -from dace.sdfg.sdfg import SDFG +from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import (AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, SDFGState, StateSubgraphView, LoopRegion, - ControlFlowRegion) +from dace.sdfg.state import (AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, SDFGState, + StateSubgraphView, LoopRegion, ControlFlowRegion) from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs @@ -1920,3 +1920,119 @@ def get_global_memlet_path_dst(sdfg: SDFG, state: SDFGState, edge: MultiConnecto pedge = pedges[0] return get_global_memlet_path_dst(psdfg, pstate, pedge) return dst + + +def get_control_flow_block_dominators(sdfg: SDFG, + idom: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None, + all_dom: Optional[Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = None, + ipostdom: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None, + all_postdom: Optional[Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = None): + """ + Find the dominator and postdominator relationship between control flow blocks of an SDFG. + This transitively computes the domination relationship across control flow regions, as if the SDFG were to be + inlined entirely. + + :param idom: A dictionary in which to store immediate dominator relationships. Not computed if None. + :param all_dom: A dictionary in which to store all dominator relationships. Not computed if None. + :param ipostdom: A dictionary in which to store immediate postdominator relationships. Not computed if None. + :param all_postdom: A dictionary in which to all postdominator relationships. Not computed if None. + """ + # Avoid cyclic import + from dace.sdfg.analysis import cfg as cfg_analysis + + if idom is not None or all_dom is not None: + added_sinks: Dict[AbstractControlFlowRegion, SDFGState] = {} + if idom is None: + idom = {} + for cfg in sdfg.all_control_flow_regions(parent_first=True): + if isinstance(cfg, ConditionalBlock): + continue + sinks = cfg.sink_nodes() + if len(sinks) > 1: + added_sinks[cfg] = cfg.add_state() + for s in sinks: + cfg.add_edge(s, added_sinks[cfg], InterstateEdge()) + idom.update(nx.immediate_dominators(cfg.nx, cfg.start_block)) + # Compute the transitive relationship of immediate dominators: + # - For every start state in a control flow region, the immediate dominator is the immediate dominator of the + # parent control flow region. + # - If the immediate dominator is a conditional or a loop, change the immediate dominator to be the immediate + # dominator of that loop or conditional. + # - If the immediate dominator is any other control flow region, change the immediate dominator to be the + # immediate dominator of that region's end / exit - or a virtual one if no single one exists. + for k, _ in idom.items(): + if k.parent_graph is not sdfg and k is k.parent_graph.start_block: + next_dom = idom[k.parent_graph] + while next_dom.parent_graph is not sdfg and next_dom is next_dom.parent_graph.start_block: + next_dom = idom[next_dom.parent_graph] + idom[k] = next_dom + changed = True + while changed: + changed = False + for k, v in idom.items(): + if isinstance(v, AbstractControlFlowRegion): + if isinstance(v, (LoopRegion, ConditionalBlock)): + idom[k] = idom[v] + else: + if v in added_sinks: + idom[k] = idom[added_sinks[v]] + else: + idom[k] = v.sink_nodes()[0] + if idom[k] is not v: + changed = True + + for cf, v in added_sinks.items(): + cf.remove_node(v) + + if all_dom is not None: + all_dom.update(cfg_analysis.all_dominators(sdfg, idom)) + + if ipostdom is not None or all_postdom is not None: + added_sinks: Dict[AbstractControlFlowRegion, SDFGState] = {} + sinks_per_cfg: Dict[AbstractControlFlowRegion, ControlFlowBlock] = {} + if ipostdom is None: + ipostdom = {} + + for cfg in sdfg.all_control_flow_regions(parent_first=True): + if isinstance(cfg, ConditionalBlock): + continue + # Get immediate post-dominators + sink_nodes = cfg.sink_nodes() + if len(sink_nodes) > 1: + sink = cfg.add_state() + added_sinks[cfg] = sink + sinks_per_cfg[cfg] = sink + for snode in sink_nodes: + cfg.add_edge(snode, sink, dace.InterstateEdge()) + elif len(sink_nodes) == 0: + return None + else: + sink = sink_nodes[0] + sinks_per_cfg[cfg] = sink + ipostdom.update(nx.immediate_dominators(cfg._nx.reverse(), sink)) + + # Compute the transitive relationship of immediate postdominators, similar to how it works for immediate + # dominators, but inverse. + for k, _ in ipostdom.items(): + if k.parent_graph is not sdfg and k is sinks_per_cfg[k.parent_graph]: + next_pdom = ipostdom[k.parent_graph] + while next_pdom.parent_graph is not sdfg and next_pdom is sinks_per_cfg[next_pdom.parent_graph]: + next_pdom = ipostdom[next_pdom.parent_graph] + ipostdom[k] = next_pdom + changed = True + while changed: + changed = False + for k, v in ipostdom.items(): + if isinstance(v, AbstractControlFlowRegion): + if isinstance(v, (LoopRegion, ConditionalBlock)): + ipostdom[k] = ipostdom[v] + else: + ipostdom[k] = v.start_block + if ipostdom[k] is not v: + changed = True + + for cf, v in added_sinks.items(): + cf.remove_node(v) + + if all_postdom is not None: + all_postdom.update(cfg_analysis.all_dominators(sdfg, ipostdom)) diff --git a/dace/transformation/dataflow/double_buffering.py b/dace/transformation/dataflow/double_buffering.py index bb42aa57ac..e0bc76818d 100644 --- a/dace/transformation/dataflow/double_buffering.py +++ b/dace/transformation/dataflow/double_buffering.py @@ -127,8 +127,8 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add initial reads to initial nested state - initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state - initial_state.label = '%s_init' % map_entry.map.label + loop_block = nsdfg_node.sdfg.start_block + initial_state = nsdfg_node.sdfg.add_state_before(loop_block, '%s_init' % map_entry.map.label) for edge in edges_to_replace: initial_state.add_node(edge.src) rnode = edge.src @@ -151,8 +151,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add the main state's contents to the last state, modifying # memlets appropriately. - final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0] - final_state.label = '%s_final_computation' % map_entry.map.label + final_state = nsdfg_node.sdfg.add_state_after(loop_block, '%s_final_computation' % map_entry.map.label) dup_nstate = copy.deepcopy(nstate) final_state.add_nodes_from(dup_nstate.nodes()) for e in dup_nstate.edges(): diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index f0e8499c92..d9798fe81a 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -5,15 +5,17 @@ from collections import defaultdict from dace import sdfg as sd, memlet as mm, subsets, data as dt from dace.codegen import control_flow as cf +from dace.properties import CodeBlock from dace.sdfg import nodes, graph as gr from dace.sdfg import utils as sdutil from dace.sdfg.propagation import propagate_memlets_state, propagate_subset +from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation, helpers from typing import List, Optional, Tuple -@transformation.single_level_sdfg_only +@transformation.experimental_cfg_block_compatible class MapFission(transformation.SingleStateTransformation): """ Implements the MapFission transformation. Map fission refers to subsuming a map scope into its internal subgraph, @@ -123,12 +125,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Get NestedSDFG control flow components nsdfg_node.sdfg.reset_cfg_list() - cf_comp = helpers.find_sdfg_control_flow(nsdfg_node.sdfg) - if len(cf_comp) == 1: - child = list(cf_comp.values())[0][1] - conditions = [] - if isinstance(child, (cf.ForScope, cf.WhileScope, cf.IfScope)): - conditions.append(child.condition if isinstance(child, (cf.ForScope, cf.IfScope)) else child.test) + if len(nsdfg_node.sdfg.nodes()) == 1: + child = nsdfg_node.sdfg.nodes()[0] + conditions: List[CodeBlock] = [] + if isinstance(child, LoopRegion): + conditions.append(child.loop_condition) + elif isinstance(child, ConditionalBlock): + for c, _ in child.branches: + if c is not None: + conditions.append(c) for cond in conditions: if any(p in cond.get_free_symbols() for p in map_node.map.params): return False @@ -138,7 +143,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False if any(p in cond.get_free_symbols() for p in map_node.map.params): return False - helpers.nest_sdfg_control_flow(nsdfg_node.sdfg, cf_comp) + helpers.nest_sdfg_control_flow(nsdfg_node.sdfg) subgraphs = list(nsdfg_node.sdfg.nodes()) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index c6a701bd48..4c6631a275 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -4,7 +4,8 @@ import itertools from networkx import MultiDiGraph -from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.properties import CodeBlock +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnBlock from dace.subsets import Range, Subset, union import dace.subsets as subsets from typing import Dict, Iterable, List, Optional, Tuple, Set, Union @@ -32,7 +33,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Nest states blocks: List[ControlFlowBlock] = subgraph.nodes() return_state = None - if len(blocks) > 1: + if len(blocks) > 1 or isinstance(blocks[0], AbstractControlFlowRegion): # Avoid cyclic imports from dace.transformation.passes.analysis import loop_analysis @@ -55,6 +56,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS is_edges: List[Edge[InterstateEdge]] = [] for b in blocks: if isinstance(b, AbstractControlFlowRegion): + all_blocks.append(b) for nb in b.all_control_flow_blocks(): all_blocks.append(nb) for e in b.all_interstate_edges(): @@ -66,6 +68,13 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS for dst in blocks: for edge in graph.edges_between(src, dst): is_edges.append(edge) + return_blocks: Set[ReturnBlock] = set([b for b in all_blocks if isinstance(b, ReturnBlock)]) + if len(return_blocks) > 0: + did_return_inner = '_did_ret_from_nsdfg' + did_return_inner = sdfg._find_new_name(did_return_inner) + sdfg.add_scalar(did_return_inner, dtypes.int32, transient=True) + else: + did_return_inner = None # Find read/write sets read_set, write_set = set(), set() @@ -90,6 +99,17 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS for s in edge.data.free_symbols: if s in sdfg.arrays: read_set.add(s) + for blk in all_blocks: + if isinstance(blk, ConditionalBlock): + for c, _ in blk.branches: + if c is not None: + for s in c.get_free_symbols(): + if s in sdfg.arrays: + read_set.add(s) + elif isinstance(blk, LoopRegion): + for s in blk.loop_condition.get_free_symbols(): + if s in sdfg.arrays: + read_set.add(s) # Find NestedSDFG's unique data rw_set = read_set | write_set @@ -113,13 +133,14 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient} # Find defined subgraph symbols + sdfg_symbols = sdfg.used_symbols(True) defined_symbols = set() strictly_defined_symbols = set() for e in is_edges: defined_symbols.update(set(e.data.assignments.keys())) for k, v in e.data.assignments.items(): try: - if k not in sdfg.symbols and k not in {str(a) for a in symbolic.pystr_to_symbolic(v).args}: + if k not in sdfg_symbols and k not in {str(a) for a in symbolic.pystr_to_symbolic(v).args}: strictly_defined_symbols.add(k) except AttributeError: # `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args` @@ -127,7 +148,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS for b in all_blocks: if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '': defined_symbols.update(b.loop_variable) - if b.loop_variable not in sdfg.symbols: + if b.loop_variable not in sdfg_symbols: if b.init_statement: init_assignment = loop_analysis.get_init_assignment(b) if b.loop_variable not in {str(s) for s in symbolic.pystr_to_symbolic(init_assignment).args}: @@ -136,16 +157,51 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS strictly_defined_symbols.add(b.loop_variable) return_state = new_state = graph.add_state('nested_sdfg_parent') + + # If there is a return that is being nested in, a conditional return is added right after the new nested SDFG + # which will be taken if the inner, nested return was hit. + ret_cond = None + if len(return_blocks) > 0: + ret_cond = ConditionalBlock('return_' + sdfg.label + '_from_nested', sdfg, graph) + graph.add_node(ret_cond, ensure_unique_name=True) + ret_branch = ControlFlowRegion('return_' + sdfg.label + '_from_nested_body', sdfg, ret_cond) + ret_block = ReturnBlock('return', sdfg, ret_branch) + ret_branch.add_node(ret_block) + ret_cond.add_branch(CodeBlock(did_return_inner), ret_branch) + nsdfg = SDFG("nested_sdfg", constants=sdfg.constants_prop, parent=new_state) nsdfg.add_node(source_node, is_start_state=True) nsdfg.add_nodes_from([s for s in blocks if s is not source_node]) for e in subgraph.edges(): nsdfg.add_edge(e.src, e.dst, e.data) - for e in graph.in_edges(source_node): - graph.add_edge(e.src, new_state, e.data) - for e in graph.out_edges(sink_node): - graph.add_edge(new_state, e.dst, e.data) + # Annotate any transitions to return blocks in the inner, nested SDFG by first setting the added transient + # scalar to 1 / true to detect that the inner SDFG returned. + if len(return_blocks) > 0: + for blk in nsdfg.all_control_flow_blocks(): + if blk in return_blocks: + pre_state = blk.parent_graph.add_state_before(blk) + did_ret_tasklet = pre_state.add_tasklet('__did_ret_set', {}, {'out'}, 'out = 1') + did_ret_access = pre_state.add_access(did_return_inner) + pre_state.add_edge(did_ret_tasklet, 'out', did_ret_access, None, Memlet(did_return_inner + '[0]')) + write_set.add(did_return_inner) + + if ret_cond is not None: + pre_state = graph.add_state('before_nested_sdfg_parent') + for e in graph.in_edges(source_node): + graph.add_edge(e.src, pre_state, e.data) + did_ret_tasklet = pre_state.add_tasklet('__did_ret_init', {}, {'out'}, 'out = 0') + did_ret_access = pre_state.add_access(did_return_inner) + pre_state.add_edge(did_ret_tasklet, 'out', did_ret_access, None, Memlet(did_return_inner + '[0]')) + graph.add_edge(pre_state, new_state, InterstateEdge()) + graph.add_edge(new_state, ret_cond, InterstateEdge()) + for e in graph.out_edges(sink_node): + graph.add_edge(ret_cond, e.dst, e.data) + else: + for e in graph.in_edges(source_node): + graph.add_edge(e.src, new_state, e.data) + for e in graph.out_edges(sink_node): + graph.add_edge(new_state, e.dst, e.data) graph.remove_nodes_from(blocks) @@ -185,7 +241,8 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Add NestedSDFG node fsymbols = sdfg.symbols.keys() | nsdfg.free_symbols - fsymbols.update(defined_symbols - strictly_defined_symbols) + fsymbols.update(defined_symbols) + fsymbols = fsymbols - strictly_defined_symbols mapping = {s: s for s in fsymbols} cnode = new_state.add_nested_sdfg(nsdfg, None, read_set, write_set, mapping) for s in strictly_defined_symbols: @@ -215,173 +272,16 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS return return_state -def _copy_state(sdfg: SDFG, - state: SDFGState, - before: bool = True, - states: Optional[Set[SDFGState]] = None) -> SDFGState: - """ - Duplicates a state, placing the copy before or after (see param before) the original and redirecting a subset of its - edges (see param state). The state is expected to be a scope's source or sink state and this method facilitates the - nesting of SDFG subgraphs where the state may be part of multiple scopes. - - :param state: The SDFGState to copy. - :param before: True if the copy should be placed before the original. - :param states: A collection of SDFGStates that should be considered for edge redirection. - :return: The SDFGState copy. - """ - - state_copy = copy.deepcopy(state) - state_copy._label += '_copy' - state_copy.parent = sdfg - sdfg.add_node(state_copy) - - in_conditions = [] - for e in sdfg.in_edges(state): - if states and e.src not in states: - continue - sdfg.add_edge(e.src, state_copy, e.data) - sdfg.remove_edge(e) - if not e.data.is_unconditional(): - in_conditions.append(e.data.condition.as_string) - - out_conditions = [] - for e in sdfg.out_edges(state): - if states and e.dst not in states: - continue - sdfg.add_edge(state_copy, e.dst, e.data) - sdfg.remove_edge(e) - if not e.data.is_unconditional(): - out_conditions.append(e.data.condition.as_string) - - if before: - condition = None - if in_conditions: - condition = 'or'.join([f"({c})" for c in in_conditions]) - sdfg.add_edge(state_copy, state, InterstateEdge(condition=condition)) - else: - condition = None - # NOTE: The following should be unecessary for preserving program semantics. Therefore we comment it out to - # avoid the overhead of evaluating the condition. - # if out_conditions: - # condition = 'or'.join([f"({c})" for c in out_conditions]) - sdfg.add_edge(state, state_copy, InterstateEdge(condition=condition)) - - return state_copy - - -def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock, - Tuple[Set[ControlFlowBlock], ControlFlowBlock]]: - """ - Partitions a CFG to subgraphs that can be nested independently of each other. The method does not nest the - subgraphs but alters the graph; (1) interstate edges are split, (2) scope source/sink nodes that belong to multiple - scopes are duplicated (see _copy_state). - - :param cfg: The graph to be partitioned. - :return: The found subgraphs in the form of a dictionary where the keys are the start block of the subgraphs and the - values are the sets of ControlFlowBlocks contained withing each subgraph. - """ - - split_interstate_edges(cfg) - - # Create a unique sink block to avoid issues with finding control flow. - sink_nodes = cfg.sink_nodes() - if len(sink_nodes) > 1: - new_sink = cfg.add_state('common_sink') - for s in sink_nodes: - cfg.add_edge(s, new_sink, InterstateEdge()) - - ipostdom = utils.postdominators(cfg) - if cfg.root_sdfg.using_experimental_blocks: - cft = cf.structured_control_flow_tree_with_regions(cfg, None) - else: - cft = cf.structured_control_flow_tree(cfg, None) - - # Iterate over the graph's control flow scopes and create for each a subraph. These subgraphs must be disjoint, - # so we duplicate blocks that appear in more than one scopes (guards and exits of loops and conditionals). - components: Dict[ControlFlowBlock, Tuple[Set[ControlFlowBlock], ControlFlowBlock]] = {} - visited: Dict[ControlFlowBlock, bool] = {} # Block -> True if block in Scope (non-SingleState) - for i, child in enumerate(cft.children): - if isinstance(child, cf.BasicCFBlock): - if child.state in visited: - continue - components[child.state] = (set([child.state]), child) - visited[child.state] = False - elif isinstance(child, (cf.ForScope, cf.WhileScope)): - guard = child.guard - fexit = None - condition = child.condition if isinstance(child, cf.ForScope) else child.test - for e in cfg.out_edges(guard): - if e.data.condition != condition: - fexit = e.dst - break - if fexit is None: - raise ValueError("Cannot find for-scope's exit states.") - - states = set(utils.dfs_conditional(cfg, [guard], lambda p, _: p is not fexit)) - - if guard in visited: - if visited[guard]: - guard_copy = _copy_state(cfg, guard, False, states) - guard.remove_nodes_from(guard.nodes()) - states.remove(guard) - states.add(guard_copy) - guard = guard_copy - else: - del components[guard] - del visited[guard] - - if not (i == len(cft.children) - 2 and isinstance(cft.children[i + 1], cf.BasicCFBlock) - and cft.children[i + 1].state is fexit): - fexit_copy = _copy_state(cfg, fexit, True, states) - fexit.remove_nodes_from(fexit.nodes()) - states.remove(fexit) - states.add(fexit_copy) - - components[guard] = (states, child) - visited.update({s: True for s in states}) - elif isinstance(child, (cf.IfScope, cf.IfElseChain)): - guard = child.branch_block - ifexit = ipostdom[guard] - - states = set(utils.dfs_conditional(cfg, [guard], lambda p, _: p is not ifexit)) - - if guard in visited: - if visited[guard]: - guard_copy = _copy_state(cfg, guard, False, states) - guard.remove_nodes_from(guard.nodes()) - states.remove(guard) - states.add(guard_copy) - guard = guard_copy - else: - del components[guard] - del visited[guard] - - if not (i == len(cft.children) - 2 and isinstance(cft.children[i + 1], cf.BasicCFBlock) - and cft.children[i + 1].state is ifexit): - ifexit_copy = _copy_state(cfg, ifexit, True, states) - ifexit.remove_nodes_from(ifexit.nodes()) - states.remove(ifexit) - states.add(ifexit_copy) - - components[guard] = (states, child) - visited.update({s: True for s in states}) - else: - raise ValueError(f"Unsupported control flow class {type(child)}") - - return components - - -def nest_sdfg_control_flow(sdfg: SDFG, components=None): +def nest_sdfg_control_flow(sdfg: SDFG): """ Partitions the SDFG to subgraphs and nests them. :param sdfg: The SDFG to be partitioned. - :param components: An existing partition of the SDFG. """ - regions = list(sdfg.all_control_flow_regions()) - for region in regions: - nest_sdfg_subgraph(region.sdfg, SubgraphView(region.sdfg, [region]), region) - sdfg.reset_cfg_list() + for nd in sdfg.nodes(): + if isinstance(nd, AbstractControlFlowRegion): + nest_sdfg_subgraph(sdfg, SubgraphView(sdfg, [nd])) + sdfg.reset_cfg_list() def nest_state_subgraph(sdfg: SDFG, diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 720b0b8b5b..e673337a35 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -19,45 +19,6 @@ SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], ControlFlowBlock]]]] -@properties.make_properties -@transformation.experimental_cfg_block_compatible -class InterstateEdgeReachability(ppl.Pass): - """ - Evaluates which interstate edges can be executed after each control flow block. - """ - - CATEGORY: str = 'Analysis' - - def modifies(self) -> ppl.Modifies: - return ppl.Modifies.Nothing - - def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.CFG - - def depends_on(self): - return {ControlFlowBlockReachability} - - def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: - """ - :return: A dictionary mapping each state to its other reachable states. - """ - # Ensure control flow block reachability is run if not run within a pipeline. - if pipeline_res is None or not ControlFlowBlockReachability.__name__ in pipeline_res: - cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) - else: - cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] - reachable: Dict[int, Dict[ControlFlowBlock, Set[Edge[InterstateEdge]]]] = {} - for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) - for state in sdfg.states(): - for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: - if isinstance(reached, SDFGState): - result[state].add(reached) - reachable[sdfg.cfg_id] = result - return reachable - - @properties.make_properties @transformation.experimental_cfg_block_compatible class StateReachability(ppl.Pass): diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index 643f397d20..48ae6f5b91 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -20,6 +20,8 @@ def program(A: dace.float64[20]): A[:] = cval + 4 sdfg = program.to_sdfg() + ScalarToSymbolPromotion().apply_pass(sdfg, {}) + ConstantPropagation().apply_pass(sdfg, {}) assert len(sdfg.symbols) == 0 for e in sdfg.edges(): diff --git a/tests/python_frontend/multiple_nested_sdfgs_test.py b/tests/python_frontend/multiple_nested_sdfgs_test.py index fc1d9f852b..722342dbfe 100644 --- a/tests/python_frontend/multiple_nested_sdfgs_test.py +++ b/tests/python_frontend/multiple_nested_sdfgs_test.py @@ -68,8 +68,8 @@ def multiple_nested_sdfgs(input: dace.float32[2, 2], output: dace.float32[2, 2]) sdfg = multiple_nested_sdfgs.to_sdfg(simplify=False) state = None - for node in sdfg.nodes(): - if re.fullmatch(r"out_tmp_div_sum_\d+_call.*", node.label): + for node in sdfg.states(): + if re.fullmatch(r"call_out_tmp_div_sum_\d+.*", node.label): assert state is None, "Two states match the regex, cannot decide which one should be used" state = node assert state is not None diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 542ff425dc..19f5e19cc6 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -5,7 +5,9 @@ from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree import numpy as np -from dace.sdfg.utils import inline_control_flow_regions +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import LoopRegion +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising def test_for_in_map_in_for(): @@ -29,14 +31,14 @@ def matmul(A: dace.float32[10, 10], B: dace.float32[10, 10], C: dace.float32[10, assert len(stree.children) == 1 # for fornode = stree.children[0] - assert isinstance(fornode, tn.ForScope) + assert isinstance(fornode, tn.GeneralLoopScope) assert len(fornode.children) == 1 # map mapnode = fornode.children[0] assert isinstance(mapnode, tn.MapScope) assert len(mapnode.children) == 2 # copy, for copynode, fornode = mapnode.children assert isinstance(copynode, tn.CopyNode) - assert isinstance(fornode, tn.ForScope) + assert isinstance(fornode, tn.GeneralLoopScope) assert len(fornode.children) == 1 # tasklet tasklet = fornode.children[0] assert isinstance(tasklet, tn.TaskletNode) @@ -82,7 +84,7 @@ def main(a: dace.float64[20, 10]): assert len(stree.children) == 4 offsets = ['', '5', '10', '15'] for fornode, offset in zip(stree.children, offsets): - assert isinstance(fornode, tn.ForScope) + assert isinstance(fornode, tn.GeneralLoopScope) assert len(fornode.children) == 1 # map mapnode = fornode.children[0] assert isinstance(mapnode, tn.MapScope) @@ -130,7 +132,7 @@ def main(a: dace.float64[20, 10]): sdfg = main.to_sdfg() stree = as_schedule_tree(sdfg) - assert isinstance(stree.children[0], tn.NView) + assert any(isinstance(v, tn.NView) for v in stree.children) def test_irreducible_sub_sdfg(): diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 7ba43ac4c0..57313c38f6 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -1,6 +1,6 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from copy import deepcopy import dace -from dace.transformation.helpers import find_sdfg_control_flow def test_read_write_set(): @@ -54,8 +54,8 @@ def double_loop(arr: dace.float32[N]): arr[i] *= 2 sdfg = double_loop.to_sdfg() - find_sdfg_control_flow(sdfg) - sdfg.validate() + copied_sdfg = deepcopy(sdfg) + copied_sdfg.validate() def test_add_mapped_tasklet(): diff --git a/tests/transformations/nest_subgraph_test.py b/tests/transformations/nest_subgraph_test.py index 623b029c3a..a9ed62cbdf 100644 --- a/tests/transformations/nest_subgraph_test.py +++ b/tests/transformations/nest_subgraph_test.py @@ -71,7 +71,7 @@ def symbolic_return(): cft = cf.structured_control_flow_tree(sdfg, None) for_scope = None for i, child in enumerate(cft.children): - if isinstance(child, (cf.ForScope, cf.WhileScope)): + if isinstance(child, (cf.GeneralLoopScope)): for_scope = child break assert for_scope @@ -80,11 +80,9 @@ def symbolic_return(): exit_scope = cft.children[i+1] assert isinstance(exit_scope, cf.BasicCFBlock) - guard = for_scope.guard - fexit = exit_scope.first_block - states = list(utils.dfs_conditional(sdfg, [guard], lambda p, _: p is not fexit)) + states = for_scope.loop.nodes() - nest_sdfg_subgraph(sdfg, SubgraphView(sdfg, states), start=guard) + nest_sdfg_subgraph(sdfg, SubgraphView(for_scope.loop, states)) result = sdfg() val = result[1][0] From 2c4c17b6963ac96ed92c1f19174ec3b99d9d80ee Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 22 Oct 2024 18:21:45 +0200 Subject: [PATCH 065/108] Fix inline multistate --- .../interstate/multistate_inline.py | 20 ++++++++++++++++++- tests/codegen/allocation_lifetime_test.py | 2 +- tests/constant_array_test.py | 2 +- tests/fortran/array_test.py | 14 ++----------- tests/inlining_test.py | 2 +- tests/multistate_init_test.py | 2 +- tests/passes/dead_code_elimination_test.py | 2 +- tests/passes/scalar_fission_test.py | 2 +- tests/transformations/gpu_transform_test.py | 2 +- .../refine_nested_access_test.py | 2 +- 10 files changed, 29 insertions(+), 21 deletions(-) diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 34a31d52b3..4563c4ab92 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -14,7 +14,9 @@ from dace.transformation import transformation, helpers from dace.properties import make_properties from dace import data -from dace.sdfg.state import LoopRegion, StateSubgraphView +from dace.sdfg.state import LoopRegion, ReturnBlock, StateSubgraphView +from dace.transformation.passes.fusion_inline import InlineControlFlowRegions +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising @make_properties @@ -135,6 +137,22 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): nsdfg_node = self.nested_sdfg nsdfg: SDFG = nsdfg_node.sdfg + # If the nested SDFG contains returns, ensure they are inlined first. + has_return = False + for blk in nsdfg.all_control_flow_blocks(): + if isinstance(blk, ReturnBlock): + has_return = True + if has_return: + inline_pass = InlineControlFlowRegions() + inline_pass.no_inline_conditional = False + inline_pass.no_inline_named_regions = False + inline_pass.no_inline_function_call_regions = False + inline_pass.no_inline_loops = False + inline_pass.apply_pass(nsdfg, {}) + # After inlining, try to lift out control flow again, essentially preserving all control flow that can be + # preserved while removing the return blocks. + ControlFlowRaising().apply_pass(nsdfg, {}) + if nsdfg_node.schedule != dtypes.ScheduleType.Default: infer_types.set_default_schedule_and_storage_types(nsdfg, [nsdfg_node.schedule]) diff --git a/tests/codegen/allocation_lifetime_test.py b/tests/codegen/allocation_lifetime_test.py index e87e6bf109..a54704b05b 100644 --- a/tests/codegen/allocation_lifetime_test.py +++ b/tests/codegen/allocation_lifetime_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. """ Tests different allocation lifetimes. """ import pytest diff --git a/tests/constant_array_test.py b/tests/constant_array_test.py index 95e92b169f..d8067f524e 100644 --- a/tests/constant_array_test.py +++ b/tests/constant_array_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from __future__ import print_function import argparse diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index 3283d2e37f..d5b8c5d669 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -1,22 +1,12 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.common.readfortran import FortranStringReader -from fparser.common.readfortran import FortranFileReader -from fparser.two.parser import ParserFactory -import sys, os import numpy as np -import pytest -from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace import dtypes, symbolic from dace.frontend.fortran import fortran_parser -from fparser.two.symbol_table import SymbolTable from dace.sdfg import utils as sdutil from dace.sdfg.nodes import AccessNode -import dace.frontend.fortran.ast_components as ast_components -import dace.frontend.fortran.ast_transforms as ast_transforms -import dace.frontend.fortran.ast_utils as ast_utils -import dace.frontend.fortran.ast_internal_classes as ast_internal_classes from dace.sdfg.state import LoopRegion diff --git a/tests/inlining_test.py b/tests/inlining_test.py index 9368d5bbce..30476362c1 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace from dace.sdfg.state import FunctionCallRegion, NamedRegion from dace.transformation.interstate import InlineSDFG, StateFusion diff --git a/tests/multistate_init_test.py b/tests/multistate_init_test.py index 8efc68d260..3359ce3e56 100644 --- a/tests/multistate_init_test.py +++ b/tests/multistate_init_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace import numpy as np diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index bf40ff4409..caebd20323 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various tests for dead code elimination passes. """ import numpy as np diff --git a/tests/passes/scalar_fission_test.py b/tests/passes/scalar_fission_test.py index eeb959a926..f8c59b8f4d 100644 --- a/tests/passes/scalar_fission_test.py +++ b/tests/passes/scalar_fission_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar fission pass. """ import pytest diff --git a/tests/transformations/gpu_transform_test.py b/tests/transformations/gpu_transform_test.py index 2099077d81..f6d299e630 100644 --- a/tests/transformations/gpu_transform_test.py +++ b/tests/transformations/gpu_transform_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ Unit tests for the GPU to-device transformation. """ import dace diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index 4c33ece899..d9fb9a7392 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Tests for the RefineNestedAccess transformation. """ import dace import numpy as np From ae205909ca3a78cfea7203dc037461f30d08ceb8 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 09:32:00 +0200 Subject: [PATCH 066/108] Fixed `state_test.py::test_read_and_write_set_filter`. Because of the removed filtering, `B` is now part of the read set as well. --- tests/sdfg/state_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 33e02088a4..7a6894c8ac 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -87,6 +87,7 @@ def test_read_and_write_set_filter(): expected_reads = { "A": [sbs.Range.from_string("0, 0")], + "B": [sbs.Range.from_string("0")], } expected_writes = { "B": [sbs.Range.from_string("0")], From db211fa169732e87c899215ebccbbbd3bd583c03 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 09:35:13 +0200 Subject: [PATCH 067/108] Fixed the `state_test.py::test_read_write_set` test. Because of the removed filtering `B` is no longer removed from the read set. However, because now the test has become quite useless. We also see that the test was specifically constructed in such a way that the filtering applies. Otherwise it would never triggered. --- tests/sdfg/state_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 7a6894c8ac..efc11cc844 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -20,7 +20,9 @@ def test_read_write_set(): state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet('B[2]')) state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet('C[2]')) - assert 'B' not in state.read_and_write_sets()[0] + read_set, write_set = state.read_and_write_sets() + assert {'B', 'A'} == read_set + assert {'C', 'B'} == write_set def test_read_write_set_y_formation(): From 570437b18ac3a780c0150fbe923babd61f354cb4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 09:40:45 +0200 Subject: [PATCH 068/108] Fixed the `state_test.py::test_read_write_set_y_formation` test. Because of the removed filtering `B` is no longer removed from the read set. However, because now the test has become quite useless. We also see that the test was specifically constructed in such a way that the filtering applies. Otherwise it would never triggered. --- tests/sdfg/state_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index efc11cc844..4bde3788e0 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -44,7 +44,9 @@ def test_read_write_set_y_formation(): state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet(data='B', subset='0')) state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet(data='C', subset='0')) - assert 'B' not in state.read_and_write_sets()[0] + read_set, write_set = state.read_and_write_sets() + assert {'B', 'A'} == read_set + assert {'C', 'B'} == write_set def test_deepcopy_state(): From 6806dc11b8c7192fd817b66c4a84bcf7a363364c Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 09:47:57 +0200 Subject: [PATCH 069/108] Fix cyclic dependency --- dace/transformation/interstate/multistate_inline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 4563c4ab92..975b08b507 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -15,8 +15,6 @@ from dace.properties import make_properties from dace import data from dace.sdfg.state import LoopRegion, ReturnBlock, StateSubgraphView -from dace.transformation.passes.fusion_inline import InlineControlFlowRegions -from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising @make_properties @@ -143,6 +141,10 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): if isinstance(blk, ReturnBlock): has_return = True if has_return: + # Avoid cyclic imports + from dace.transformation.passes.fusion_inline import InlineControlFlowRegions + from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising + inline_pass = InlineControlFlowRegions() inline_pass.no_inline_conditional = False inline_pass.no_inline_named_regions = False From ab255e8b9c89254db17de7758639e954bf9b2dba Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 10:03:28 +0200 Subject: [PATCH 070/108] Fixes to codegen and data instrumentation --- dace/codegen/targets/cpu.py | 12 ++++++------ tests/codegen/allocation_lifetime_test.py | 2 +- tests/codegen/data_instrumentation_test.py | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 5155c011ab..6d3d82641e 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -850,7 +850,7 @@ def _emit_copy( # Instrumentation: Pre-copy for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_copy_begin(sdfg, state_dfg, src_node, dst_node, edge, stream, None, copy_shape, + instr.on_copy_begin(sdfg, cfg, state_dfg, src_node, dst_node, edge, stream, None, copy_shape, src_strides, dst_strides) nc = True @@ -912,7 +912,7 @@ def _emit_copy( # Instrumentation: Post-copy for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_copy_end(sdfg, state_dfg, src_node, dst_node, edge, stream, None) + instr.on_copy_end(sdfg, cfg, state_dfg, src_node, dst_node, edge, stream, None) ############################################################# ########################################################################### @@ -1502,7 +1502,7 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra # Instrumentation: Pre-tasklet instr = self._dispatcher.instrumentation[node.instrument] if instr is not None: - instr.on_node_begin(sdfg, state_dfg, node, outer_stream_begin, inner_stream, function_stream) + instr.on_node_begin(sdfg, cfg, state_dfg, node, outer_stream_begin, inner_stream, function_stream) inner_stream.write("\n ///////////////////\n", cfg, state_id, node) @@ -1531,7 +1531,7 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra # Instrumentation: Post-tasklet if instr is not None: - instr.on_node_end(sdfg, state_dfg, node, outer_stream_end, inner_stream, function_stream) + instr.on_node_end(sdfg, cfg, state_dfg, node, outer_stream_end, inner_stream, function_stream) callsite_stream.write(outer_stream_begin.getvalue(), cfg, state_id, node) callsite_stream.write('{', cfg, state_id, node) @@ -2151,7 +2151,7 @@ def _generate_AccessNode(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub # Instrumentation: Pre-node instr = self._dispatcher.instrumentation[node.instrument] if instr is not None: - instr.on_node_begin(sdfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) + instr.on_node_begin(sdfg, cfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) sdict = state_dfg.scope_dict() for edge in state_dfg.in_edges(node): @@ -2194,7 +2194,7 @@ def _generate_AccessNode(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub # Instrumentation: Post-node if instr is not None: - instr.on_node_end(sdfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) + instr.on_node_end(sdfg, cfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) # Methods for subclasses to override diff --git a/tests/codegen/allocation_lifetime_test.py b/tests/codegen/allocation_lifetime_test.py index a54704b05b..9a68cd2140 100644 --- a/tests/codegen/allocation_lifetime_test.py +++ b/tests/codegen/allocation_lifetime_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ Tests different allocation lifetimes. """ import pytest diff --git a/tests/codegen/data_instrumentation_test.py b/tests/codegen/data_instrumentation_test.py index b254a204b5..aef9c83df3 100644 --- a/tests/codegen/data_instrumentation_test.py +++ b/tests/codegen/data_instrumentation_test.py @@ -317,12 +317,9 @@ def dinstr(A: dace.float64[20]): assert len(dreport.keys()) == 1 assert 'i' in dreport.keys() - assert len(dreport['i']) == 22 - desired = list(range(1, 19)) - s_idx = dreport['i'].index(1) - e_idx = dreport['i'].index(18) - assert np.allclose(dreport['i'][s_idx:e_idx+1], desired) - assert 19 in dreport['i'] + assert len(dreport['i']) == 19 + desired = list(range(0, 19)) + assert np.allclose(dreport['i'], desired) @pytest.mark.datainstrument @@ -356,7 +353,10 @@ def dinstr(A: dace.float64[20]): for i in range(j): A[i] = 0 - sdfg = dinstr.to_sdfg(simplify=True) + # Simplification is turned off to avoid killing the initial start state, since symbol instrumentation can for now + # only be triggered on SDFG states. + # TODO(later): Make it so symbols can be instrumented on any Control flow block + sdfg = dinstr.to_sdfg(simplify=False) sdfg.start_state.symbol_instrument = dace.DataInstrumentationType.Save A = np.ones((20, )) sdfg(A, j=15) From e97f5bc786a31c5090f93e36da67013b4ba7388d Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 10:40:22 +0200 Subject: [PATCH 071/108] Fix subgraph nesting --- dace/transformation/helpers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 4c6631a275..2ac4ac49a3 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -133,22 +133,21 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient} # Find defined subgraph symbols - sdfg_symbols = sdfg.used_symbols(True) defined_symbols = set() strictly_defined_symbols = set() for e in is_edges: defined_symbols.update(set(e.data.assignments.keys())) for k, v in e.data.assignments.items(): try: - if k not in sdfg_symbols and k not in {str(a) for a in symbolic.pystr_to_symbolic(v).args}: + if k not in sdfg.symbols and k not in {str(a) for a in symbolic.pystr_to_symbolic(v).args}: strictly_defined_symbols.add(k) except AttributeError: # `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args` pass for b in all_blocks: if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '': - defined_symbols.update(b.loop_variable) - if b.loop_variable not in sdfg_symbols: + defined_symbols.add(b.loop_variable) + if b.loop_variable not in sdfg.symbols: if b.init_statement: init_assignment = loop_analysis.get_init_assignment(b) if b.loop_variable not in {str(s) for s in symbolic.pystr_to_symbolic(init_assignment).args}: @@ -220,8 +219,11 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS ndefined_symbols = set() out_mapping = {} out_state = None - for e in nsdfg.edges(): + for e in nsdfg.all_interstate_edges(): ndefined_symbols.update(set(e.data.assignments.keys())) + for b in all_blocks: + if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '' and b.init_statement: + ndefined_symbols.add(b.loop_variable) if ndefined_symbols: out_state = nsdfg.add_state('symbolic_output') nsdfg.add_edge(sink_node, out_state, InterstateEdge()) From 8de1c1ef80df4ed28ab72349b583b45369646fc1 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 13:39:29 +0200 Subject: [PATCH 072/108] Fixes to GPU codegen --- dace/codegen/targets/cuda.py | 8 +- dace/sdfg/state.py | 28 +++- dace/sdfg/utils.py | 13 +- .../interstate/gpu_transform_sdfg.py | 151 ++++++++++++------ 4 files changed, 140 insertions(+), 60 deletions(-) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 23d48fe9ea..58b9d974e9 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -156,8 +156,8 @@ def preprocess(self, sdfg: SDFG) -> None: # Find GPU<->GPU strided copies that cannot be represented by a single copy command from dace.transformation.dataflow import CopyToMap for e, state in list(sdfg.all_edges_recursive()): - nsdfg = state.parent if isinstance(e.src, nodes.AccessNode) and isinstance(e.dst, nodes.AccessNode): + nsdfg = state.parent if (e.src.desc(nsdfg).storage == dtypes.StorageType.GPU_Global and e.dst.desc(nsdfg).storage == dtypes.StorageType.GPU_Global): copy_shape, src_strides, dst_strides, _, _ = memlet_copy_to_absolute_strides( @@ -774,7 +774,7 @@ def increment(streams): state_streams = [] state_subsdfg_events = [] - for state in sdfg.nodes(): + for state in sdfg.states(): # Start by annotating source nodes source_nodes = state.source_nodes() @@ -872,7 +872,7 @@ def increment(streams): # Compute maximal number of events by counting edges (within the same # state) that point from one stream to another state_events = [] - for i, state in enumerate(sdfg.nodes()): + for i, state in enumerate(sdfg.states()): events = state_subsdfg_events[i] for e in state.edges(): @@ -2011,7 +2011,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S bidx = krange.coord_at(dsym) # handle dynamic map inputs - for e in dace.sdfg.dynamic_map_inputs(sdfg.states()[state_id], dfg_scope.source_nodes()[0]): + for e in dace.sdfg.dynamic_map_inputs(cfg.node(state_id), dfg_scope.source_nodes()[0]): kernel_stream.write( self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index ca733258df..47dc3b29c1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -726,8 +726,12 @@ def update_if_not_none(dic, update): defined_syms[str(sym)] = sym.dtype # Add inter-state symbols - for edge in sdfg.dfs_edges(sdfg.start_state): + if isinstance(sdfg.start_block, LoopRegion): + update_if_not_none(defined_syms, sdfg.start_block.new_symbols(defined_syms)) + for edge in sdfg.all_interstate_edges(): update_if_not_none(defined_syms, edge.data.new_symbols(sdfg, defined_syms)) + if isinstance(edge.dst, LoopRegion): + update_if_not_none(defined_syms, edge.dst.new_symbols(defined_syms)) # Add scope symbols all the way to the subgraph sdict = state.scope_dict() @@ -3208,6 +3212,28 @@ def _used_symbols_internal(self, return free_syms, defined_syms, used_before_assignment + def new_symbols(self, symbols) -> Dict[str, dtypes.typeclass]: + """ + Returns a mapping between the symbol defined by this loop and its type, if it exists. + """ + # Avoid cyclic import + from dace.codegen.tools.type_inference import infer_expr_type + from dace.transformation.passes.analysis import loop_analysis + + if self.init_statement and self.loop_variable: + alltypes = copy.copy(symbols) + alltypes.update({k: v.dtype for k, v in self.sdfg.arrays.items()}) + l_end = loop_analysis.get_loop_end(self) + l_start = loop_analysis.get_init_assignment(self) + l_step = loop_analysis.get_loop_stride(self) + inferred_type = dtypes.result_type_of(infer_expr_type(l_start, alltypes), + infer_expr_type(l_step, alltypes), + infer_expr_type(l_end, alltypes)) + init_rhs = loop_analysis.get_init_assignment(self) + if self.loop_variable not in symbolic.free_symbols_and_functions(init_rhs): + return {self.loop_variable: inferred_type} + return {} + def replace_dict(self, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index e92929c36f..502adb8bb7 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1507,9 +1507,14 @@ def _traverse(scope: Node, symbols: Dict[str, dtypes.typeclass]): def _tswds_cf_region( sdfg: SDFG, - region: ControlFlowRegion, + region: AbstractControlFlowRegion, symbols: Dict[str, dtypes.typeclass], recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: + if isinstance(region, ConditionalBlock): + for _, b in region.branches: + yield from _tswds_cf_region(sdfg, b, symbols, recursive) + return + # Add symbols from inter-state edges along the state machine start_region = region.start_block visited = set() @@ -1522,7 +1527,7 @@ def _tswds_cf_region( visited.add(edge.src) if isinstance(edge.src, SDFGState): yield from _tswds_state(sdfg, edge.src, {}, recursive) - elif isinstance(edge.src, ControlFlowRegion): + elif isinstance(edge.src, AbstractControlFlowRegion): yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) # Add edge symbols into defined symbols @@ -1534,14 +1539,14 @@ def _tswds_cf_region( visited.add(edge.dst) if isinstance(edge.dst, SDFGState): yield from _tswds_state(sdfg, edge.dst, symbols, recursive) - elif isinstance(edge.dst, ControlFlowRegion): + elif isinstance(edge.dst, AbstractControlFlowRegion): yield from _tswds_cf_region(sdfg, edge.dst, symbols, recursive) # If there is only one state, the DFS will miss it if start_region not in visited: if isinstance(start_region, SDFGState): yield from _tswds_state(sdfg, start_region, symbols, recursive) - elif isinstance(start_region, ControlFlowRegion): + elif isinstance(start_region, AbstractControlFlowRegion): yield from _tswds_cf_region(sdfg, start_region, symbols, recursive) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 2753844fc1..d710d58782 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -4,7 +4,9 @@ from dace import data, memlet, dtypes, sdfg as sd, subsets as sbs, propagate_memlets_sdfg from dace.sdfg import nodes, scope from dace.sdfg import utils as sdutil -from dace.sdfg.state import SDFGState +from dace.sdfg.replace import replace_in_codeblock +from dace.sdfg.sdfg import memlets_in_ast +from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState from dace.transformation import transformation, helpers as xfh from dace.properties import Property, make_properties from collections import defaultdict @@ -252,11 +254,23 @@ def apply(self, _, sdfg: sd.SDFG): if not found_full_write: input_nodes.append((onodename, onode)) + check_memlets: List[memlet.Memlet] = [] for edge in sdfg.all_interstate_edges(): - memlets = edge.data.get_read_memlets(sdfg.arrays) - for mem in memlets: - if sdfg.arrays[mem.data].storage == dtypes.StorageType.GPU_Global: - data_already_on_gpu[mem.data] = None + check_memlets.extend(edge.data.get_read_memlets(sdfg.arrays)) + for blk in sdfg.all_control_flow_blocks(): + if isinstance(blk, ConditionalBlock): + for c, _ in blk.branches: + if c is not None: + check_memlets.extend(memlets_in_ast(c.code[0], sdfg.arrays)) + elif isinstance(blk, LoopRegion): + check_memlets.extend(memlets_in_ast(blk.loop_condition.code[0], sdfg.arrays)) + if blk.init_statement: + check_memlets.extend(memlets_in_ast(blk.init_statement.code[0], sdfg.arrays)) + if blk.update_statement: + check_memlets.extend(memlets_in_ast(blk.update_statement.code[0], sdfg.arrays)) + for mem in check_memlets: + if sdfg.arrays[mem.data].storage == dtypes.StorageType.GPU_Global: + data_already_on_gpu[mem.data] = None # Replace nodes for state in sdfg.states(): @@ -473,63 +487,98 @@ def apply(self, _, sdfg: sd.SDFG): cloned_data = set(cloned_arrays.keys()).union(gpu_scalars.keys()).union(data_already_on_gpu.keys()) - for state in list(sdfg.states()): + def _create_copy_out(arrays_used: Set[str]) -> Dict[str, str]: + # Add copy-out nodes + name_mapping = {} + for nname in arrays_used: + # Handle GPU scalars + if nname in gpu_scalars: + hostname = gpu_scalars[nname] + if not hostname: + desc = sdfg.arrays[nname].clone() + desc.storage = dtypes.StorageType.CPU_Heap + desc.transient = True + hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) + gpu_scalars[nname] = hostname + else: + desc = sdfg.arrays[hostname] + devicename = nname + elif nname in data_already_on_gpu: + hostname = data_already_on_gpu[nname] + if not hostname: + desc = sdfg.arrays[nname].clone() + desc.storage = dtypes.StorageType.CPU_Heap + desc.transient = True + hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) + data_already_on_gpu[nname] = hostname + else: + desc = sdfg.arrays[hostname] + devicename = nname + else: + desc = sdfg.arrays[nname] + hostname = nname + devicename = cloned_arrays[nname] + + src_array = nodes.AccessNode(devicename, debuginfo=desc.debuginfo) + dst_array = nodes.AccessNode(hostname, debuginfo=desc.debuginfo) + co_state.add_node(src_array) + co_state.add_node(dst_array) + co_state.add_nedge(src_array, dst_array, + memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg))) + name_mapping[devicename] = hostname + return name_mapping + + for block in list(sdfg.all_control_flow_blocks()): arrays_used = set() - for e in state.parent_graph.out_edges(state): + for e in block.parent_graph.out_edges(block): # Used arrays = intersection between symbols and cloned data arrays_used.update(set(e.data.free_symbols) & cloned_data) # Create a state and copy out used arrays if len(arrays_used) > 0: - - co_state = state.parent_graph.add_state(state.label + '_icopyout') + co_state = block.parent_graph.add_state(block.label + '_icopyout') # Reconnect outgoing edges to after interim copyout state - for e in state.parent_graph.out_edges(state): - sdutil.change_edge_src(state.parent_graph, state, co_state) + for e in block.parent_graph.out_edges(block): + sdutil.change_edge_src(block.parent_graph, block, co_state) # Add unconditional edge to interim state - state.parent_graph.add_edge(state, co_state, sd.InterstateEdge()) - - # Add copy-out nodes - for nname in arrays_used: - - # Handle GPU scalars - if nname in gpu_scalars: - hostname = gpu_scalars[nname] - if not hostname: - desc = sdfg.arrays[nname].clone() - desc.storage = dtypes.StorageType.CPU_Heap - desc.transient = True - hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) - gpu_scalars[nname] = hostname - else: - desc = sdfg.arrays[hostname] - devicename = nname - elif nname in data_already_on_gpu: - hostname = data_already_on_gpu[nname] - if not hostname: - desc = sdfg.arrays[nname].clone() - desc.storage = dtypes.StorageType.CPU_Heap - desc.transient = True - hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) - data_already_on_gpu[nname] = hostname - else: - desc = sdfg.arrays[hostname] - devicename = nname - else: - desc = sdfg.arrays[nname] - hostname = nname - devicename = cloned_arrays[nname] - - src_array = nodes.AccessNode(devicename, debuginfo=desc.debuginfo) - dst_array = nodes.AccessNode(hostname, debuginfo=desc.debuginfo) - co_state.add_node(src_array) - co_state.add_node(dst_array) - co_state.add_nedge(src_array, dst_array, - memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg))) - for e in state.parent_graph.out_edges(co_state): + block.parent_graph.add_edge(block, co_state, sd.InterstateEdge()) + mapping = _create_copy_out(arrays_used) + for devicename, hostname in mapping.items(): + for e in block.parent_graph.out_edges(co_state): e.data.replace(devicename, hostname, False) + for block in list(sdfg.all_control_flow_blocks()): + arrays_used = set() + if isinstance(block, ConditionalBlock): + for c, _ in block.branches: + if c is not None: + arrays_used.update(set(c.get_free_symbols()) & cloned_data) + elif isinstance(block, LoopRegion): + arrays_used.update(set(block.loop_condition.get_free_symbols()) & cloned_data) + if block.init_statement: + arrays_used.update(set(block.init_statement.get_free_symbols()) & cloned_data) + if block.update_statement: + arrays_used.update(set(block.update_statement.get_free_symbols()) & cloned_data) + else: + continue + + # Create a state and copy out used arrays + if len(arrays_used) > 0: + co_state = block.parent_graph.add_state_before(block, block.label + '_icopyout') + mapping = _create_copy_out(arrays_used) + for devicename, hostname in mapping.items(): + if isinstance(block, ConditionalBlock): + for c, _ in block.branches: + if c is not None: + replace_in_codeblock(c, {devicename: hostname}) + elif isinstance(block, LoopRegion): + replace_in_codeblock(block.loop_condition, {devicename: hostname}) + if block.init_statement: + replace_in_codeblock(block.init_statement, {devicename: hostname}) + if block.update_statement: + replace_in_codeblock(block.update_statement, {devicename: hostname}) + # Step 9: Simplify if not self.simplify: return From cb80f0b3288cedd7ea89acb035cae7873f7afe2f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 14:34:04 +0200 Subject: [PATCH 073/108] Fixed `move_loop_into_map_test.py::MoveLoopIntoMapTest::test_more_than_a_map`. Because the filtering is now not applied, like it was originally, the transformation does no longer apply. However, this is the historical expected behaviour. --- .../move_loop_into_map_test.py | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index fbb05d30f5..ad51941cb0 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -150,7 +150,9 @@ def test_apply_multiple_times_1(self): def test_more_than_a_map(self): """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. - However, there is no dependency. + + Note that there is actually no dependency, however, the transformation, because it relies + on `SDFGState.read_and_write_sets()` it can not detect this and can thus not be applied. """ sdfg = dace.SDFG('more_than_a_map') _, aarr = sdfg.add_array('A', (3, 3), dace.float64) @@ -175,26 +177,8 @@ def test_more_than_a_map(self): body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - sdfg_args_ref = { - "A": np.array(np.random.rand(3, 3), dtype=np.float64), - "B": np.array(np.random.rand(3, 3), dtype=np.float64), - "out": np.array(np.random.rand(3, 3), dtype=np.float64), - } - sdfg_args_res = copy.deepcopy(sdfg_args_ref) - - # Perform the reference execution - sdfg(**sdfg_args_ref) - - # Apply the transformation and execute the SDFG again. count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) - sdfg(**sdfg_args_res) - - for name in sdfg_args_ref.keys(): - self.assertTrue( - np.allclose(sdfg_args_ref[name], sdfg_args_res[name]), - f"Miss match for {name}", - ) - self.assertTrue(count > 0) + self.assertTrue(count == 0) def test_more_than_a_map_1(self): """ From b704a4378cd8b03a1d1e5469458118c2fd70ce2d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 14:39:03 +0200 Subject: [PATCH 074/108] Fixed `prune_connectors_test.py::test_read_write_*`. Because of the removed filtering the transformat can no longer apply. However, I originally added these tests to demonstrate this inconsistent behaviour in the first place. So removing them is now the right choice. This commit also combines them to `prune_connectors_test.py::test_read_write`. --- .../transformations/prune_connectors_test.py | 25 +++---------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 9995fbd305..b7b287d77e 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -331,16 +331,6 @@ def test_unused_retval_2(): assert np.allclose(a, 1) -def test_read_write_1(): - # Because the memlet is conforming, we can apply the transformation. - sdfg = _make_read_write_sdfg(True) - - assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False) - - - - - def test_prune_connectors_with_dependencies(): sdfg = dace.SDFG('tester') A, A_desc = sdfg.add_array('A', [4], dace.float64) @@ -419,21 +409,12 @@ def test_prune_connectors_with_dependencies(): assert np.allclose(np_d, np_d_) -def test_read_write_1(): +def test_read_write(): sdfg, nsdfg = _make_read_write_sdfg(True) + assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) - assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) - sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True) - - -def test_read_write_2(): - # In previous versions of DaCe the transformation could not be applied in the - # case of a not conforming Memlet. - # See [PR#1678](https://github.com/spcl/dace/pull/1678) sdfg, nsdfg = _make_read_write_sdfg(False) - - assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) - sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True) + assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) if __name__ == "__main__": From f74d6e81f979483ce849606f5911c9ece3adec3f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 14:46:27 +0200 Subject: [PATCH 075/108] General improvements to some tests. Ensured that the return value is always accassable in the same way. Also ensured that the `test_rna_read_and_write_sets_different_storage()` test verifies that it still gives the same result. --- .../refine_nested_access_test.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index 94c02a88de..38e17ad3af 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -177,12 +177,12 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False) A = state.add_access("A") - T1_input = state.add_access("T1") + T1_output = state.add_access("T1") if diff_in_out: sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False) - T1_output = state.add_access("T2") + T1_input = state.add_access("T2") else: - T1_output = state.add_access("T1") + T1_input = state.add_access("T1") tsklt = state.add_tasklet( "comp", @@ -195,7 +195,7 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: # An alternative would be to write to a different location here. # Then, the data would be added to the access node. state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]")) - state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet("T1[0]")) + state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet(T1_input.data + "[0]")) state.add_edge(tsklt, "__out", T1_output, None, dace.Memlet(T1_output.data + "[1]")) return sdfg @@ -254,6 +254,16 @@ def test_rna_read_and_write_sets_different_storage(): ) assert nb_applied > 0 + args = { + "A": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T2": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T1": np.zeros(2, dtype=np.float64), + } + ref = args["A"][0] + args["A"][1] + sdfg(**args) + res = args["T1"][1] + assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'." + if __name__ == '__main__': test_refine_dataflow() From e1039243d847f964bdec5ced1b2cdebeb1c977c5 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 14:50:14 +0200 Subject: [PATCH 076/108] Updated `refine_nested_access_test.py::test_rna_read_and_write_sets_doule_use` I realized that the transformation should not apply. The reason is because of all arguments to the nested SDFG element `0` is accessed. This means it can not be adjusted. --- tests/transformations/refine_nested_access_test.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index 38e17ad3af..81640665ed 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -228,18 +228,15 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: def test_rna_read_and_write_sets_doule_use(): - """ - NOTE: Under the current definition of the read/write sets this test will fail. - """ - - # The output is used also as temporary storage. + # The transformation does not apply because we access element `0` of both arrays that we + # pass inside the nested SDFG. sdfg = _make_rna_read_and_write_set_sdfg(False) nb_applied = sdfg.apply_transformations_repeated( [RefineNestedAccess], validate=True, validate_all=True, ) - assert nb_applied > 0 + assert nb_applied == 0 def test_rna_read_and_write_sets_different_storage(): From 56e756df72ea63642e75403e4f771ac5ab932ac7 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 14:28:57 +0200 Subject: [PATCH 077/108] More GPU fixes --- dace/codegen/instrumentation/gpu_events.py | 2 +- dace/codegen/targets/cuda.py | 4 +- dace/transformation/interstate/loop_to_map.py | 6 +-- .../subgraph/gpu_persistent_fusion.py | 39 ++++++++++++++----- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/dace/codegen/instrumentation/gpu_events.py b/dace/codegen/instrumentation/gpu_events.py index 3d367d444e..6e0d483a43 100644 --- a/dace/codegen/instrumentation/gpu_events.py +++ b/dace/codegen/instrumentation/gpu_events.py @@ -105,7 +105,7 @@ def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, loc if state.instrument == dtypes.InstrumentationType.GPU_Events: idstr = self._idstr(cfg, state, None) local_stream.write(self._record_event('e' + idstr, 0), cfg, state_id) - local_stream.write(self._report('State %s' % state.label, sdfg, state), cfg, state_id) + local_stream.write(self._report('State %s' % state.label, cfg, state), cfg, state_id) local_stream.write(self._destroy_event('b' + idstr), cfg, state_id) local_stream.write(self._destroy_event('e' + idstr), cfg, state_id) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 58b9d974e9..bad937518d 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1434,7 +1434,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub create_grid_barrier = True self.create_grid_barrier = create_grid_barrier - kernel_name = '%s_%d_%d_%d' % (scope_entry.map.label, sdfg.cfg_id, state.block_id, state.node_id(scope_entry)) + kernel_name = '%s_%d_%d_%d' % (scope_entry.map.label, cfg.cfg_id, state.block_id, state.node_id(scope_entry)) # Comprehend grid/block dimensions from scopes grid_dims, block_dims, tbmap, dtbmap, _ = self.get_kernel_dimensions(dfg_scope) @@ -2061,7 +2061,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S assert CUDACodeGen._in_device_code is False CUDACodeGen._in_device_code = True self._kernel_map = node - self._kernel_state = sdfg.node(state_id) + self._kernel_state = cfg.node(state_id) self._block_dims = block_dims self._grid_dims = grid_dims diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 8d74248270..4495996b8a 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -394,7 +394,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Create NestedSDFG and add the loop contents to it. Gaher symbols defined in the NestedSDFG. fsymbols = set(sdfg.free_symbols) - body = graph.add_state('single_state_body', is_start_block=(graph.start_block is self.loop)) + body = graph.add_state_before(self.loop, 'single_state_body') nsdfg = SDFG('loop_body', constants=sdfg.constants_prop, parent=body) nsdfg.add_node(self.loop.start_block, is_start_block=True) nsymbols = dict() @@ -554,11 +554,9 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): if not source_nodes and not sink_nodes: body.add_nedge(entry, exit, memlet.Memlet()) - # Redirect edges connected to the loop to connect to the body state instead. + # Redirect outgoing edges connected to the loop to connect to the body state instead. for e in graph.out_edges(self.loop): graph.add_edge(body, e.dst, e.data) - for e in graph.in_edges(self.loop): - graph.add_edge(e.src, body, e.data) # Delete the loop and connected edges. graph.remove_node(self.loop) diff --git a/dace/transformation/subgraph/gpu_persistent_fusion.py b/dace/transformation/subgraph/gpu_persistent_fusion.py index ff4812d0af..b7c201a3d7 100644 --- a/dace/transformation/subgraph/gpu_persistent_fusion.py +++ b/dace/transformation/subgraph/gpu_persistent_fusion.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy import dace -from dace import dtypes, nodes, registry, Memlet +from dace import nodes, Memlet from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.dtypes import StorageType, ScheduleType from dace.properties import Property, make_properties -from dace.sdfg.utils import concurrent_subgraphs +from dace.sdfg.state import AbstractControlFlowRegion, LoopRegion from dace.sdfg.graph import SubgraphView from dace.transformation.transformation import SubgraphTransformation @@ -68,12 +68,18 @@ class GPUPersistentKernel(SubgraphTransformation): @staticmethod def can_be_applied(sdfg: SDFG, subgraph: SubgraphView): - if not set(subgraph.nodes()).issubset(set(sdfg.nodes())): return False + subgraph_blocks = set() + for nd in subgraph.nodes(): + subgraph_blocks.add(nd) + if isinstance(nd, AbstractControlFlowRegion): + subgraph_blocks.update(nd.all_control_flow_blocks()) + subgraph_states = set([blk for blk in subgraph_blocks if isinstance(blk, SDFGState)]) + # All states need to be GPU states - for state in subgraph: + for state in subgraph_states: if not GPUPersistentKernel.is_gpu_state(sdfg, state): return False @@ -114,6 +120,12 @@ def can_be_applied(sdfg: SDFG, subgraph: SubgraphView): def apply(self, sdfg: SDFG): subgraph = self.subgraph_view(sdfg) + subgraph_blocks = set() + for nd in subgraph.nodes(): + subgraph_blocks.add(nd) + if isinstance(nd, AbstractControlFlowRegion): + subgraph_blocks.update(nd.all_control_flow_blocks()) + entry_states_in, entry_states_out = self.get_entry_states(sdfg, subgraph) _, exit_states_out = self.get_exit_states(sdfg, subgraph) @@ -181,13 +193,22 @@ def apply(self, sdfg: SDFG): new_symbols.add(k) if k in sdfg.symbols and k not in kernel_sdfg.symbols: kernel_sdfg.add_symbol(k, sdfg.symbols[k]) + for blk in subgraph_blocks: + if isinstance(blk, LoopRegion): + if blk.loop_variable and blk.init_statement: + new_symbols.add(blk.loop_variable) + if blk.loop_variable in sdfg.symbols and blk.loop_variable not in kernel_sdfg.symbols: + kernel_sdfg.add_symbol(blk.loop_variable, sdfg.symbols[blk.loop_variable]) # Setting entry node in nested SDFG if no entry guard was created if entry_guard_state is None: kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in) - for state in subgraph: - state.parent = kernel_sdfg + for nd in subgraph: + nd.sdfg = kernel_sdfg + if isinstance(nd, AbstractControlFlowRegion): + for n in nd.all_control_flow_blocks(): + n.sdfg = kernel_sdfg # remove the now nested nodes from the outer sdfg and make sure the # launch state is properly connected to remaining states @@ -203,7 +224,7 @@ def apply(self, sdfg: SDFG): sdfg.add_edge(launch_state, exit_state_out, InterstateEdge()) # Handle data for kernel - kernel_data = set(node.data for state in kernel_sdfg for node in state.nodes() + kernel_data = set(node.data for state in kernel_sdfg.states() for node in state.nodes() if isinstance(node, nodes.AccessNode)) other_data = set(node.data for state in other_states for node in state.nodes() if isinstance(node, nodes.AccessNode)) @@ -230,7 +251,7 @@ def apply(self, sdfg: SDFG): kernel_args_write = set() for data in kernel_args: data_accesses_read_only = [ - state.in_degree(node) == 0 for state in kernel_sdfg for node in state + state.in_degree(node) == 0 for state in kernel_sdfg.states() for node in state if isinstance(node, nodes.AccessNode) and node.data == data ] if all(data_accesses_read_only): From 6e14e6d9964951fb98892d160c862307cfefff7f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 17:35:31 +0200 Subject: [PATCH 078/108] More fixes --- dace/transformation/auto/auto_optimize.py | 2 ++ dace/transformation/interstate/loop_to_map.py | 7 ++++++- .../transformation/interstate/move_loop_into_map.py | 2 ++ dace/transformation/passes/analysis/analysis.py | 13 +++++-------- .../transformation/passes/analysis/loop_analysis.py | 2 ++ .../passes/dead_dataflow_elimination.py | 4 +--- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 1c77160338..1a9e8ba30b 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -662,6 +662,8 @@ def auto_optimize(sdfg: SDFG, print("Specializing the SDFG for symbols", known_symbols) sdfg.specialize(known_symbols) + sdfg.reset_cfg_list() + # Validate at the end if validate or validate_all: sdfg.validate() diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 4495996b8a..55327af5fb 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -11,7 +11,7 @@ from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowRegion, LoopRegion, ReturnBlock import dace.transformation.helpers as helpers from dace.transformation import transformation as xf from dace.transformation.passes.analysis import loop_analysis @@ -94,6 +94,11 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): if start is None or end is None or step is None or itervar is None: return False + # Loops containing break, continue, or returns may not be turned into a map. + for blk in self.loop.all_control_flow_blocks(): + if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)): + return False + # We cannot handle symbols read from data containers unless they are scalar. for expr in (start, end, step): if symbolic.contains_sympy_functions(expr): diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index 3ac1f5a9e9..fd7c4353dd 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -229,6 +229,8 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): if helpers.is_symbol_unused(sdfg, s): sdfg.remove_symbol(s) + sdfg.reset_cfg_list() + from dace.transformation.interstate import RefineNestedAccess transformation = RefineNestedAccess() transformation.setup_match(sdfg, body.parent_graph.cfg_id, body.block_id, diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index e673337a35..d41f36c387 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -235,13 +235,12 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If access nodes were modified, reapply return modified & ppl.Modifies.AccessNodes - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]: """ :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. """ - top_result: Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]] = {} + result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} arrays: Set[str] = set(sdfg.arrays.keys()) for block in sdfg.all_control_flow_blocks(): readset, writeset = set(), set() @@ -282,9 +281,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tupl if fsyms: result[e.src][0].update(fsyms) result[e.dst][0].update(fsyms) - - top_result[sdfg.cfg_id] = result - return top_result + return result @properties.make_properties @@ -557,6 +554,8 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i """ top_result: Dict[int, WriteScopeDict] = dict() + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[AccessSets.__name__] + for sdfg in top_sdfg.all_sdfgs_recursive(): result: WriteScopeDict = defaultdict(lambda: defaultdict(lambda: set())) idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} @@ -576,8 +575,6 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i all_doms_transitive[k].add(cfg) all_doms_transitive[k].update(all_doms_transitive[cfg]) - access_sets: Dict[ControlFlowBlock, Tuple[Set[str], - Set[str]]] = pipeline_results[AccessSets.__name__][sdfg.cfg_id] access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[ FindAccessNodes.__name__][sdfg.cfg_id] diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index ec9d4d0c73..69a77422e8 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -16,6 +16,8 @@ def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: """ Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). """ + if loop.loop_variable is None or loop.loop_variable == '': + return None end: Optional[symbolic.SymbolicType] = None a = sympy.Wild('a') condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index e429fa902d..9bc8e27dd2 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -64,9 +64,7 @@ def apply(self, region, pipeline_results): reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[ ap.ControlFlowBlockReachability.__name__ ][region.cfg_id] - access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[ - ap.AccessSets.__name__ - ][sdfg.cfg_id] + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[ap.AccessSets.__name__] result: Dict[SDFGState, Set[str]] = defaultdict(set) # Traverse region backwards From 0c6535920c6876a3aa7432c68d285d93b98549e0 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 18:56:43 +0200 Subject: [PATCH 079/108] More bugfixes --- dace/codegen/targets/fpga.py | 4 ++-- dace/sdfg/sdfg.py | 4 ++-- dace/transformation/interstate/fpga_transform_state.py | 2 +- dace/transformation/interstate/state_elimination.py | 2 +- dace/transformation/testing.py | 6 ++++-- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 0c74d6ec07..6955c838cb 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -1105,7 +1105,7 @@ def generate_nested_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: dace. self._dispatcher.dispatch_subgraph(sdfg, cfg, sg, - sdfg.node_id(state), + cfg.node_id(state), function_stream, callsite_stream, skip_entry_node=False) @@ -1974,7 +1974,7 @@ def _is_innermost(self, scope, scope_dict, sdfg): return False to_search += scope_dict[x] elif isinstance(x, dace.sdfg.nodes.NestedSDFG): - for state in x.sdfg: + for state in x.sdfg.states(): if not self._is_innermost(state.nodes(), state.scope_children(), x.sdfg): return False return True diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 38a41236a6..a6c0d99dc6 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1327,7 +1327,7 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set = set() write_set = set() for state in self.states(): - for edge in self.in_edges(state): + for edge in state.parent_graph.in_edges(state): read_set |= edge.data.free_symbols & self.arrays.keys() # Get dictionaries of subsets read and written from each state rs, ws = state._read_and_write_sets() @@ -1483,7 +1483,7 @@ def transients(self): result = {} tstate = {} - for (i, state) in enumerate(self.nodes()): + for (i, state) in enumerate(self.states()): scope_dict = state.scope_dict() for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.desc(self).transient: diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index 6e1af4ed16..287b725abb 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -27,7 +27,7 @@ def fpga_update(sdfg: SDFG, state: SDFGState, depth: int): if (hasattr(node, "schedule") and node.schedule == dace.dtypes.ScheduleType.Default): node.schedule = dace.dtypes.ScheduleType.FPGA_Device if isinstance(node, nodes.NestedSDFG): - for s in node.sdfg.nodes(): + for s in node.sdfg.states(): fpga_update(node.sdfg, s, depth + 1) diff --git a/dace/transformation/interstate/state_elimination.py b/dace/transformation/interstate/state_elimination.py index 6ffe9fa468..8755155615 100644 --- a/dace/transformation/interstate/state_elimination.py +++ b/dace/transformation/interstate/state_elimination.py @@ -78,7 +78,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): state = self.start_state # The transformation applies only to nested SDFGs - if not graph.parent: + if not isinstance(graph, SDFG) or not graph.parent: return False # Only empty states can be eliminated diff --git a/dace/transformation/testing.py b/dace/transformation/testing.py index 79738c9ec3..dea1be2a9b 100644 --- a/dace/transformation/testing.py +++ b/dace/transformation/testing.py @@ -6,6 +6,7 @@ import traceback from dace.sdfg import SDFG +from dace.sdfg.state import ControlFlowRegion from dace.transformation.optimizer import Optimizer @@ -68,8 +69,9 @@ def _optimize_recursive(self, sdfg: SDFG, depth: int): print(' ' * depth, type(match).__name__, '- ', end='', file=self.stdout) - tsdfg: SDFG = new_sdfg.cfg_list[match.cfg_id] - tgraph = tsdfg.node(match.state_id) if match.state_id >= 0 else tsdfg + tcfg: ControlFlowRegion = new_sdfg.cfg_list[match.cfg_id] + tsdfg = tcfg.sdfg if not isinstance(tcfg, SDFG) else tcfg + tgraph = tcfg.node(match.state_id) if match.state_id >= 0 else tcfg match._sdfg = tsdfg match.apply(tgraph, tsdfg) From ac72cb137e78d34c6b305c3de875e6a1ea660448 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 23 Oct 2024 19:45:07 +0200 Subject: [PATCH 080/108] Fixes --- dace/transformation/dataflow/mpi.py | 6 +++--- dace/transformation/dataflow/tiling.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/mpi.py b/dace/transformation/dataflow/mpi.py index c44c21e9b9..e838c61648 100644 --- a/dace/transformation/dataflow/mpi.py +++ b/dace/transformation/dataflow/mpi.py @@ -102,7 +102,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): rangeexpr = str(map_entry.map.range.num_elements()) stripmine_subgraph = {StripMining.map_entry: self.subgraph[MPITransformMap.map_entry]} - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id stripmine = StripMining() stripmine.setup_match(sdfg, cfg_id, self.state_id, stripmine_subgraph, self.expr_index) stripmine.dim_idx = -1 @@ -128,7 +128,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): LocalStorage.node_a: graph.node_id(outer_map), LocalStorage.node_b: self.subgraph[MPITransformMap.map_entry] } - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id in_local_storage = InLocalStorage() in_local_storage.setup_match(sdfg, cfg_id, self.state_id, in_local_storage_subgraph, self.expr_index) in_local_storage.array = e.data.data @@ -146,7 +146,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): LocalStorage.node_a: graph.node_id(in_map_exit), LocalStorage.node_b: graph.node_id(out_map_exit) } - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id outlocalstorage = OutLocalStorage() outlocalstorage.setup_match(sdfg, cfg_id, self.state_id, outlocalstorage_subgraph, self.expr_index) outlocalstorage.array = name diff --git a/dace/transformation/dataflow/tiling.py b/dace/transformation/dataflow/tiling.py index bfa899e71a..b9e4a59b0c 100644 --- a/dace/transformation/dataflow/tiling.py +++ b/dace/transformation/dataflow/tiling.py @@ -54,7 +54,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): from dace.transformation.dataflow.map_collapse import MapCollapse from dace.transformation.dataflow.strip_mining import StripMining stripmine_subgraph = {StripMining.map_entry: self.subgraph[MapTiling.map_entry]} - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id last_map_entry = None removed_maps = 0 From cc9e5f4ab43cf05dd759b3518146266208675b8a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 24 Oct 2024 10:37:22 +0200 Subject: [PATCH 081/108] FPGA fixes --- dace/codegen/targets/fpga.py | 18 ++++++++++-------- dace/sdfg/state.py | 2 +- dace/transformation/interstate/block_fusion.py | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 6955c838cb..92d4bdd741 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -421,10 +421,11 @@ def find_rtl_tasklet(self, subgraph: ScopeSubgraphView): ''' for n in subgraph.nodes(): if isinstance(n, dace.nodes.NestedSDFG): - for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): - node = self.find_rtl_tasklet(sg) - if node: - return node + if len(n.sdfg.nodes()) == 1 and isinstance(n.sdfg.nodes()[0], SDFGState): + for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): + node = self.find_rtl_tasklet(sg) + if node: + return node elif isinstance(n, dace.nodes.Tasklet) and n.language == dace.dtypes.Language.SystemVerilog: return n return None @@ -438,9 +439,10 @@ def is_multi_pumped_subgraph(self, subgraph: ScopeSubgraphView): ''' for n in subgraph.nodes(): if isinstance(n, dace.nodes.NestedSDFG): - for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): - if self.is_multi_pumped_subgraph(sg): - return True + if len(n.sdfg.nodes()) == 1 and isinstance(n.sdfg.nodes()[0], SDFGState): + for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.nodes()[0]): + if self.is_multi_pumped_subgraph(sg): + return True elif isinstance(n, dace.nodes.MapEntry) and n.schedule == dace.ScheduleType.FPGA_Multi_Pumped: return True return False @@ -1720,7 +1722,7 @@ def _emit_copy(self, sdfg: SDFG, cfg: ControlFlowRegion, state_id: int, src_node raise NotImplementedError("Reads from shift registers only supported from tasklets.") # Try to turn into degenerate/strided ND copies - state_dfg = sdfg.nodes()[state_id] + state_dfg = cfg.node(state_id) copy_shape, src_strides, dst_strides, src_expr, dst_expr = (cpp.memlet_copy_to_absolute_strides( self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, packed_types=True)) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 4a7e88a8b3..4c9ef1f248 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2999,7 +2999,7 @@ def start_block(self, block_id): if block_id < 0 or block_id >= self.number_of_nodes(): raise ValueError('Invalid state ID') self._start_block = block_id - self._cached_start_block = self.node(block_id) + self._cached_start_block = None @make_properties diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py index 6abd65fc87..71736fc269 100644 --- a/dace/transformation/interstate/block_fusion.py +++ b/dace/transformation/interstate/block_fusion.py @@ -78,6 +78,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True def apply(self, graph: ControlFlowRegion, sdfg): + first_is_start = graph.start_block is self.first_block connecting_edge = graph.edges_between(self.first_block, self.second_block)[0] assignments_to_absorb = connecting_edge.data.assignments graph.remove_edge(connecting_edge) @@ -87,12 +88,11 @@ def apply(self, graph: ControlFlowRegion, sdfg): if self._is_noop(self.first_block): # We remove the first block and let the second one remain. - first_is_start = graph.start_block is self.first_block for ie in graph.in_edges(self.first_block): graph.add_edge(ie.src, self.second_block, ie.data) - graph.remove_node(self.first_block) if first_is_start: graph.start_block = self.second_block.block_id + graph.remove_node(self.first_block) else: # We remove the second block and let the first one remain. for oe in graph.out_edges(self.second_block): From bc9f61ea25aca5b7a608f2e4ad27565e9c44367e Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 29 Oct 2024 10:56:58 +0100 Subject: [PATCH 082/108] Adapt state propagation into a pass to adapt it --- dace/sdfg/propagation.py | 13 +- dace/sdfg/state.py | 17 +- .../transformation/interstate/sdfg_nesting.py | 4 + dace/transformation/pass_pipeline.py | 4 +- .../passes/analysis/analysis.py | 167 +++++++++++++++++- .../simplification/control_flow_raising.py | 6 +- tests/state_propagation_test.py | 136 +++++++++++--- 7 files changed, 301 insertions(+), 46 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index ab23bd540c..5e9f182437 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -675,8 +675,8 @@ def _annotate_loop_ranges(sdfg: 'SDFG', unannotated_cycle_states): loop_states = sdutils.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard) for v in loop_states: - v.ranges[itervar] = subsets.Range([rng]) - guard.ranges[itervar] = subsets.Range([rng]) + v.ranges[str(itervar)] = subsets.Range([rng]) + guard.ranges[str(itervar)] = subsets.Range([rng]) condition_edges[guard] = sdfg.edges_between(guard, begin)[0] guard.is_loop_guard = True guard.itvar = itervar @@ -744,6 +744,15 @@ def propagate_states(sdfg: 'SDFG', concretize_dynamic_unbounded: bool = False) - :note: This operates on the SDFG in-place. """ + if sdfg.using_experimental_blocks: + # Avoid cyclic imports + from dace.transformation.pass_pipeline import Pipeline + from dace.transformation.passes.analysis import StatePropagation + + state_prop_pipeline = Pipeline([StatePropagation()]) + state_prop_pipeline.apply_pass(sdfg, {}) + return + # We import here to avoid cyclic imports. from dace.sdfg import InterstateEdge from dace.sdfg.analysis import cfg diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 4c9ef1f248..d5083fc99b 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1157,6 +1157,12 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): pre_conditions = DictProperty(key_type=str, value_type=list, desc='Pre-conditions for this block') post_conditions = DictProperty(key_type=str, value_type=list, desc='Post-conditions for this block') invariant_conditions = DictProperty(key_type=str, value_type=list, desc='Invariant conditions for this block') + ranges = DictProperty(key_type=str, value_type=Range, default={}, + desc='Variable ranges across this block, typically within loops') + + executions = SymbolicProperty(default=0, + desc="The number of times this block gets executed (0 stands for unbounded)") + dynamic_executions = Property(dtype=bool, default=True, desc="The number of executions of this block is dynamic") _label: str @@ -1291,17 +1297,6 @@ class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlo symbol_instrument_condition = CodeProperty(desc="Condition under which to trigger the symbol instrumentation", default=CodeBlock("1", language=dtypes.Language.CPP)) - executions = SymbolicProperty(default=0, - desc="The number of times this state gets " - "executed (0 stands for unbounded)") - dynamic_executions = Property(dtype=bool, default=True, desc="The number of executions of this state " - "is dynamic") - - ranges = DictProperty(key_type=symbolic.symbol, - value_type=Range, - default={}, - desc='Variable ranges, typically within loops') - location = DictProperty(key_type=str, value_type=symbolic.pystr_to_symbolic, desc='Full storage location identifier (e.g., rank, GPU ID)') diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index ff096a3198..9c65077572 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -16,6 +16,7 @@ from dace.sdfg.graph import MultiConnectorEdge, SubgraphView from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil, infer_types, propagation +from dace.sdfg.state import LoopRegion from dace.transformation import transformation, helpers from dace.properties import make_properties, Property from dace import data @@ -1285,6 +1286,9 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: for e in nested_sdfg.edges(): defined_syms |= set(e.data.new_symbols(sdfg, {}).keys()) + for blk in nested_sdfg.all_control_flow_blocks(): + if isinstance(blk, LoopRegion): + defined_syms |= set(blk.new_symbols({}).keys()) defined_syms |= set(nested_sdfg.constants.keys()) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index e558ab0b20..bca7626b85 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -277,6 +277,8 @@ class ControlFlowRegionPass(Pass): apply_to_conditionals = properties.Property(dtype=bool, default=False, desc='Whether or not to apply to conditional blocks. If false, do ' + 'not apply to conditional blocks, but only their children.') + top_down = properties.Property(dtype=bool, default=False, + desc='Whether or not to apply top down (i.e., parents before children)') def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[int, Optional[Any]]]: """ @@ -290,7 +292,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D if nothing was returned. """ result = {} - for region in sdfg.all_control_flow_regions(recursive=True, parent_first=False): + for region in sdfg.all_control_flow_regions(recursive=True, parent_first=self.top_down): if isinstance(region, ConditionalBlock) and not self.apply_to_conditionals: continue retval = self.apply(region, pipeline_results) diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index d41f36c387..053b2f997e 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -1,14 +1,17 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from collections import defaultdict +from collections import defaultdict, deque + +import sympy from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.subsets import Range from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd from dace.sdfg.analysis import cfg as cfg_analysis -from typing import Dict, Set, Tuple, Any, Optional, Union +from typing import Dict, Iterable, List, Set, Tuple, Any, Optional, Union import networkx as nx from networkx.algorithms import shortest_paths as nxsp @@ -754,3 +757,163 @@ def apply_pass(self, sdfg: SDFG, _) -> Tuple[Dict[str, Set[str]], Dict[str, Set[ invariants: Dict[str, Set[str]] = {} self._derive_parameter_datasize_constraints(sdfg, invariants) return {}, invariants, {} + + +@transformation.experimental_cfg_block_compatible +class StatePropagation(ppl.ControlFlowRegionPass): + """ + Analyze a control flow region to determine the number of times each block inside of it is executed in the form of a + symbolic expression, or a concrete number where possible. + Each control flow block is marked with a symbolic expression for the number of executions, and a boolean flag to + indicate whether the number of executions is dynamic or not. A combination of dynamic being set to true and the + number of executions being 0 indicates that the number of executions is dynamically unbounded. + Additionally, the pass annotates each block with a `ranges` property, which indicates for loop variables defined + at that block what range of values the variable may take on. + Note: This path directly annotates the graph. + This pass supersedes `dace.sdfg.propagation.propagate_states` and is based on its algorithm, with significant + simplifications thanks to the use of control flow regions. + """ + + CATEGORY: str = 'Analysis' + + def __init__(self): + super().__init__() + self.top_down = True + self.apply_to_conditionals = True + + def depends_on(self): + return {ControlFlowBlockReachability} + + def _propagate_in_cfg(self, cfg: ControlFlowRegion, reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]], + starting_executions: int, starting_dynamic_executions: bool): + visited_blocks: Set[ControlFlowBlock] = set() + traversal_q: deque[Tuple[ControlFlowBlock, int, bool, List[str]]] = deque() + traversal_q.append((cfg.start_block, starting_executions, starting_dynamic_executions, [])) + while traversal_q: + (block, proposed_executions, proposed_dynamic, itvar_stack) = traversal_q.pop() + out_edges = cfg.out_edges(block) + if block in visited_blocks: + # This block has already been visited, meaning there are multiple paths towards this block. + if proposed_executions == 0 and proposed_dynamic: + block.executions = 0 + block.dynamic_executions = True + else: + block.executions = sympy.Max(block.executions, proposed_executions).doit() + block.dynamic_executions = (block.dynamic_executions or proposed_dynamic) + elif proposed_dynamic and proposed_executions == 0: + # We're propagating a dynamic unbounded number of executions, which always gets propagated + # unconditionally. Propagate to all children. + visited_blocks.add(block) + block.executions = proposed_executions + block.dynamic_executions = proposed_dynamic + # This gets pushed through to all children unconditionally. + if len(out_edges) > 0: + for oedge in out_edges: + traversal_q.append((oedge.dst, proposed_executions, proposed_dynamic, itvar_stack)) + else: + # If the state hasn't been visited yet and we're not propagating a dynamic unbounded number of + # executions, we calculate the number of executions for the next state(s) and continue propagating. + visited_blocks.add(block) + block.executions = proposed_executions + block.dynamic_executions = proposed_dynamic + if len(out_edges) == 1: + # Continue with the only child state. + if not out_edges[0].data.is_unconditional(): + # If the transition to the child state is based on a condition, this state could be an implicit + # exit state. The child state's number of executions is thus only given as an upper bound and + # marked as dynamic. + proposed_dynamic = True + traversal_q.append((out_edges[0].dst, proposed_executions, proposed_dynamic, itvar_stack)) + elif len(out_edges) > 1: + # Conditional split + for oedge in out_edges: + traversal_q.append((oedge.dst, block.executions, True, itvar_stack)) + + # Check if the CFG contains any cycles. Any cycles left in the graph (after control flow raising) are + # irreducible control flow and thus lead to a dynamically unbounded number of executions. Mark any block + # inside and reachable from any block inside the cycle as dynamically unbounded, irrespectively of what it was + # marked as before. + cycles: Iterable[Iterable[ControlFlowBlock]] = cfg.find_cycles() + for cycle in cycles: + for blk in cycle: + blk.executions = 0 + blk.dynamic_executions = True + for reached in reachable[blk]: + reached.executions = 0 + blk.dynamic_executions = True + + def apply(self, region, pipeline_results) -> None: + if isinstance(region, ConditionalBlock): + # In a conditional block, each branch is executed up to as many times as the conditional block itself is. + # TODO(later): We may be able to derive ranges here based on the branch conditions too. + for _, b in region.branches: + b.executions = region.executions + b.dynamic_executions = True + b.ranges = region.ranges + else: + if isinstance(region, SDFG): + # The root SDFG is executed exactly once, any other, nested SDFG is executed as many times as the parent + # state is. + if region is region.root_sdfg: + region.executions = 1 + region.dynamic_executions = False + elif region.parent: + region.executions = region.parent.executions + region.dynamic_executions = region.parent.dynamic_executions + + # Clear existing annotations. + for blk in region.nodes(): + blk.executions = 0 + blk.dynamic_executions = True + blk.ranges = region.ranges + + # Determine the number of executions for the start block within this region. In the case of loops, this + # is dependent on the number of loop iterations - where they can be determined. Where they may not be + # determined, the number of iterations is assumed to be dynamically unbounded. For any other control flow + # region, the start block is executed as many times as the region itself is. + starting_execs = region.executions + starting_dynamic = region.dynamic_executions + if isinstance(region, LoopRegion): + # If inside a loop, add range information if possible. + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + if start is not None and stop is not None and stride is not None and region.loop_variable: + # This inequality needs to be checked exactly like this due to constraints in sympy/symbolic + # expressions, do not simplify! + if (stride < 0) == True: + rng = (stop, start, -stride) + else: + rng = (start, stop, stride) + for blk in region.nodes(): + blk.ranges[str(region.loop_variable)] = Range([rng]) + + # Get surrounding iteration variables for the case of nested loops. + itvar_stack = [] + par = region.parent_graph + while par is not None and not isinstance(par, SDFG): + if isinstance(par, LoopRegion) and par.loop_variable: + itvar_stack.append(par.loop_variable) + par = par.parent_graph + + # Calculate the number of loop executions. + # This resolves ranges based on the order of iteration variables from surrounding loops. + loop_executions = sympy.ceiling(((stop + 1) - start) / stride) + for outer_itvar_string in itvar_stack: + outer_range = region.ranges[outer_itvar_string] + outer_start = outer_range[0][0] + outer_stop = outer_range[0][1] + outer_stride = outer_range[0][2] + outer_itvar = symbolic.pystr_to_symbolic(outer_itvar_string) + exec_repl = loop_executions.subs({outer_itvar: (outer_itvar * outer_stride + outer_start)}) + sum_rng = (outer_itvar, 0, sympy.ceiling((outer_stop - outer_start) / outer_stride)) + loop_executions = sympy.Sum(exec_repl, sum_rng) + starting_execs = loop_executions.doit() + starting_dynamic = region.dynamic_executions + else: + starting_execs = 0 + starting_dynamic = True + + # Propagate the number of executions. + self._propagate_in_cfg(region, pipeline_results[ControlFlowBlockReachability.__name__][region.cfg_id], + starting_execs, starting_dynamic) diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index b852b798b1..b89a09b196 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -62,6 +62,7 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: if oe.dst is merge_block: # Empty branch. branch.add_state('noop') + graph.remove_edge(oe) continue branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) @@ -87,7 +88,10 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: region.remove_node(dummy_exit) n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) - return n_cond_regions_post - n_cond_regions_pre + lifted = n_cond_regions_post - n_cond_regions_pre + if lifted: + sdfg.root_sdfg.using_experimental_blocks = True + return lifted def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: lifted_loops = 0 diff --git a/tests/state_propagation_test.py b/tests/state_propagation_test.py index 226775a0e7..2984a7707a 100644 --- a/tests/state_propagation_test.py +++ b/tests/state_propagation_test.py @@ -1,10 +1,12 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import pytest from dace.dtypes import Language from dace.properties import CodeProperty, CodeBlock from dace.sdfg.sdfg import InterstateEdge import dace from dace.sdfg.propagation import propagate_states +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising def state_check_executions(state, expected, expected_dynamic=False): @@ -16,7 +18,8 @@ def state_check_executions(state, expected, expected_dynamic=False): raise RuntimeError('Expected static executions, got dynamic') -def test_conditional_fake_merge(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_conditional_fake_merge(with_regions): sdfg = dace.SDFG('fake_merge') state_init = sdfg.add_state('init') @@ -40,13 +43,17 @@ def test_conditional_fake_merge(): sdfg.add_edge(state_c, state_e, InterstateEdge(condition=CodeProperty.from_string('not (j < 10)', language=Language.Python))) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) state_check_executions(state_d, 1, True) state_check_executions(state_e, 1, True) -def test_conditional_full_merge(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_conditional_full_merge(with_regions): sdfg = dace.SDFG('conditional_full_merge') sdfg.add_scalar('a', dace.int32) @@ -71,6 +78,9 @@ def test_conditional_full_merge(): sdfg.add_edge(r_branch, if_merge_2, dace.InterstateEdge()) sdfg.add_edge(if_merge_2, if_merge_1, dace.InterstateEdge()) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. @@ -92,7 +102,8 @@ def test_conditional_full_merge(): state_check_executions(if_merge_1, 1) -def test_while_inside_for(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_while_inside_for(with_regions): sdfg = dace.SDFG('while_inside_for') sdfg.add_symbol('i', dace.int32) @@ -116,13 +127,19 @@ def test_while_inside_for(): sdfg.add_edge(guard_2, loop_2, dace.InterstateEdge(condition=CodeBlock('j < 20'))) sdfg.add_edge(loop_2, guard_2, dace.InterstateEdge()) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # Check the for loop guard, `i in range(20)`. - state_check_executions(guard_1, 21) + if with_regions: + state_check_executions(guard_1, 20) + else: + state_check_executions(guard_1, 21) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. @@ -136,7 +153,8 @@ def test_while_inside_for(): state_check_executions(loop_2, 0, expected_dynamic=True) -def test_for_with_nested_full_merge_branch(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_for_with_nested_full_merge_branch(with_regions): sdfg = dace.SDFG('for_full_merge') sdfg.add_symbol('i', dace.int32) @@ -171,13 +189,19 @@ def test_for_with_nested_full_merge_branch(): sdfg.add_edge(r_branch, if_merge, dace.InterstateEdge()) sdfg.add_edge(if_merge, guard_1, dace.InterstateEdge(assignments={'i': 'i + 1'})) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # For loop, check loop guard, `for i in range(20)`. - state_check_executions(guard_1, 21) + if with_regions: + state_check_executions(guard_1, 20) + else: + state_check_executions(guard_1, 21) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. @@ -190,7 +214,8 @@ def test_for_with_nested_full_merge_branch(): state_check_executions(if_merge, 20) -def test_for_inside_branch(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_for_inside_branch(with_regions): sdfg = dace.SDFG('for_in_branch') state_init = sdfg.add_state('init') @@ -218,15 +243,22 @@ def test_for_inside_branch(): 'j': 'j + 1', })) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) state_check_executions(branch_guard, 1, False) - state_check_executions(loop_guard, 11, True) + if with_regions: + state_check_executions(loop_guard, 10, True) + else: + state_check_executions(loop_guard, 11, True) state_check_executions(loop_state, 10, True) state_check_executions(branch_merge, 1, False) -def test_full_merge_inside_loop(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_full_merge_inside_loop(with_regions): sdfg = dace.SDFG('full_merge_inside_loop') state_init = sdfg.add_state('init') @@ -256,16 +288,23 @@ def test_full_merge_inside_loop(): 'i': 'i + 1', })) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) - state_check_executions(loop_guard, 11, False) + if with_regions: + state_check_executions(loop_guard, 10, False) + else: + state_check_executions(loop_guard, 11, False) state_check_executions(branch_guard, 10, False) state_check_executions(branch_state, 10, True) state_check_executions(branch_merge, 10, False) state_check_executions(loop_end, 1, False) -def test_while_with_nested_full_merge_branch(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_while_with_nested_full_merge_branch(with_regions): sdfg = dace.SDFG('while_full_merge') sdfg.add_scalar('a', dace.int32) @@ -299,6 +338,9 @@ def test_while_with_nested_full_merge_branch(): sdfg.add_edge(r_branch, if_merge, dace.InterstateEdge()) sdfg.add_edge(if_merge, guard_1, dace.InterstateEdge()) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. @@ -318,7 +360,8 @@ def test_while_with_nested_full_merge_branch(): state_check_executions(if_merge, 0, expected_dynamic=True) -def test_3_fold_nested_loop_with_symbolic_bounds(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_3_fold_nested_loop_with_symbolic_bounds(with_regions): N = dace.symbol('N') M = dace.symbol('M') K = dace.symbol('K') @@ -355,34 +398,47 @@ def test_3_fold_nested_loop_with_symbolic_bounds(): sdfg.add_edge(guard_3, loop_3, dace.InterstateEdge(condition=CodeBlock('k < K'))) sdfg.add_edge(loop_3, guard_3, dace.InterstateEdge(assignments={'k': 'k + 1'})) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # 1st level loop, check loop guard, `for i in range(N)`. - state_check_executions(guard_1, N + 1) + if with_regions: + state_check_executions(guard_1, N) + else: + state_check_executions(guard_1, N + 1) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. state_check_executions(loop_1, N) # 2nd level nested loop, check loog guard, `for j in range(M)`. - state_check_executions(guard_2, M * N + N) + if with_regions: + state_check_executions(guard_2, M * N) + else: + state_check_executions(guard_2, M * N + N) # Check loop-end branch. state_check_executions(end_2, N) # Check inside the loop. state_check_executions(loop_2, M * N) # 3rd level nested loop, check loop guard, `for k in range(K)`. - state_check_executions(guard_3, M * N * K + M * N) + if with_regions: + state_check_executions(guard_3, M * N * K) + else: + state_check_executions(guard_3, M * N * K + M * N) # Check loop-end branch. state_check_executions(end_3, M * N) # Check inside the loop. state_check_executions(loop_3, M * N * K) -def test_3_fold_nested_loop(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_3_fold_nested_loop(with_regions): sdfg = dace.SDFG('nest_3') sdfg.add_symbol('i', dace.int32) @@ -415,27 +471,40 @@ def test_3_fold_nested_loop(): sdfg.add_edge(guard_3, loop_3, dace.InterstateEdge(condition=CodeBlock('k < j'))) sdfg.add_edge(loop_3, guard_3, dace.InterstateEdge(assignments={'k': 'k + 1'})) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # 1st level loop, check loop guard, `for i in range(20)`. - state_check_executions(guard_1, 21) + if with_regions: + state_check_executions(guard_1, 20) + else: + # When using a state-machine-style loop, the guard is executed N+1 times for N loop iterations. + state_check_executions(guard_1, 21) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. state_check_executions(loop_1, 20) # 2nd level nested loop, check loog guard, `for j in range(i, 20)`. - state_check_executions(guard_2, 230) + if with_regions: + state_check_executions(guard_2, 210) + else: + state_check_executions(guard_2, 230) # Check loop-end branch. state_check_executions(end_2, 20) # Check inside the loop. state_check_executions(loop_2, 210) # 3rd level nested loop, check loop guard, `for k in range(i, j)`. - state_check_executions(guard_3, 1540) + if with_regions: + state_check_executions(guard_3, 1330) + else: + state_check_executions(guard_3, 1540) # Check loop-end branch. state_check_executions(end_3, 210) # Check inside the loop. @@ -443,12 +512,21 @@ def test_3_fold_nested_loop(): if __name__ == "__main__": - test_3_fold_nested_loop() - test_3_fold_nested_loop_with_symbolic_bounds() - test_while_with_nested_full_merge_branch() - test_for_with_nested_full_merge_branch() - test_for_inside_branch() - test_while_inside_for() - test_conditional_full_merge() - test_conditional_fake_merge() - test_full_merge_inside_loop() + test_3_fold_nested_loop(False) + test_3_fold_nested_loop_with_symbolic_bounds(False) + test_while_with_nested_full_merge_branch(False) + test_for_with_nested_full_merge_branch(False) + test_for_inside_branch(False) + test_while_inside_for(False) + test_conditional_full_merge(False) + test_conditional_fake_merge(False) + test_full_merge_inside_loop(False) + test_3_fold_nested_loop(True) + test_3_fold_nested_loop_with_symbolic_bounds(True) + test_while_with_nested_full_merge_branch(True) + test_for_with_nested_full_merge_branch(True) + test_for_inside_branch(True) + test_while_inside_for(True) + test_conditional_full_merge(True) + test_conditional_fake_merge(True) + test_full_merge_inside_loop(True) From f63f75df075c5296935258b316c2adb0dade298b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 29 Oct 2024 11:52:47 +0100 Subject: [PATCH 083/108] Fix w/d test inlining --- tests/sdfg/work_depth_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index dd1c3eb518..39d13a1380 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -227,6 +227,8 @@ def test_work_depth(test_name): inliner.no_inline_function_call_regions = False inliner.no_inline_named_regions = False inliner.apply_pass(sdfg, {}) + for sd in sdfg.all_sdfgs_recursive(): + sd.using_experimental_blocks = False analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] @@ -283,6 +285,8 @@ def test_avg_par(test_name: str): inliner.no_inline_function_call_regions = False inliner.no_inline_named_regions = False inliner.apply_pass(sdfg, {}) + for sd in sdfg.all_sdfgs_recursive(): + sd.using_experimental_blocks = False analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] From 48c7cb4bd961cfc5b669c7ccebc4b29570909f16 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 29 Oct 2024 18:21:25 +0100 Subject: [PATCH 084/108] Fix to block fusion --- dace/transformation/interstate/block_fusion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py index 71736fc269..37da066be2 100644 --- a/dace/transformation/interstate/block_fusion.py +++ b/dace/transformation/interstate/block_fusion.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from dace.sdfg import utils as sdutil -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, SDFGState +from dace.sdfg.state import AbstractControlFlowRegion, ControlFlowBlock, ControlFlowRegion, SDFGState from dace.transformation import transformation @@ -53,6 +53,9 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if out_edges[0].data.assignments: if not in_edges: return False + # If the first block is a control flow region, no absorbtion is possible. + if isinstance(self.first_block, AbstractControlFlowRegion): + return False # Fail if symbol is set before the block to fuse new_assignments = set(out_edges[0].data.assignments.keys()) if any((new_assignments & set(e.data.assignments.keys())) for e in in_edges): From a0e2c59e13991be0029ccd6ddff8a8170a65b082 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 29 Oct 2024 18:55:43 +0100 Subject: [PATCH 085/108] Derped a test.. --- tests/sdfg/work_depth_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 39d13a1380..159e2eb8b6 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -196,7 +196,7 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_8') * N, 2 * sp.Symbol('num_execs_0_8'))), 'break_for_loop': (break_for_loop, (N**2, N)), 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_7') * N, sp.Symbol('num_execs_0_7'))), 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), From 0828fa174eda9ba4a0f1f31b575a2602e8ec5b05 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 5 Nov 2024 13:14:09 +0100 Subject: [PATCH 086/108] Fix inlining --- dace/sdfg/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 5bdac9a220..8015c6dd4d 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1270,8 +1270,7 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> progress = True pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) - or not isinstance(u, SDFGState)): + if u in skip_nodes or v in skip_nodes: continue if isinstance(u, SDFGState) and isinstance(v, SDFGState): From 3094fa1dd6469029964425f58d5278d186f33692 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 12 Nov 2024 09:07:26 +0100 Subject: [PATCH 087/108] Update SDFV --- dace/viewer/webclient | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/viewer/webclient b/dace/viewer/webclient index c6b8fe4fd2..f8f3e9d352 160000 --- a/dace/viewer/webclient +++ b/dace/viewer/webclient @@ -1 +1 @@ -Subproject commit c6b8fe4fd2c3616b0480ead4c24d8012b91a31fd +Subproject commit f8f3e9d352ad28794ecddf94fbb04d888083f6fa From 38fcbafddce8e69a673c82dd7b9314c14c6df003 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 12 Nov 2024 15:39:49 +0100 Subject: [PATCH 088/108] Endless loop in constant prop fix --- .../passes/constant_propagation.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 5350106672..8149bc3a60 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -255,7 +255,7 @@ def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arra else: # No else branch is present, so it is possible that no branch is executed. In this case the out constants # are the intersection between the in constants and the post constants. - out_consts = in_consts + out_consts = in_consts.copy() for k, v in post_consts.items(): if k not in out_consts: out_consts[k] = _UnknownValue @@ -364,14 +364,6 @@ def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str] if reassignments and (used_in_assignments - reassignments): assignments[aname] = _UnknownValue - if isinstance(block, LoopRegion): - # Any constants before a loop that may be overwritten inside the loop cannot be assumed as constants - # for the loop itself. - assigned_in_loop = self._assignments_in_loop(block) - for k in assignments.keys(): - if k in assigned_in_loop: - assignments[k] = _UnknownValue - if block not in in_const_dict: in_const_dict[block] = {} if assignments: @@ -385,9 +377,9 @@ def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str] post_const_dict, out_const_dict) else: # Simple case, no change in constants through this block (states and other basic blocks). - pre_const_dict[block] = in_const_dict[block] - post_const_dict[block] = in_const_dict[block] - out_const_dict[block] = in_const_dict[block] + pre_const_dict[block] = in_const_dict[block].copy() + post_const_dict[block] = in_const_dict[block].copy() + out_const_dict[block] = in_const_dict[block].copy() # For all sink nodes, compute the overlapping set of constants between them, making sure all constants in the # resulting intersection are actually constants (i.e., all blocks see the same constant value for them). This From 567d3076a56a6ce527d53c9ab31c163e9d41c936 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 12 Nov 2024 16:51:50 +0100 Subject: [PATCH 089/108] Fix propagation --- dace/transformation/passes/constant_propagation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 8149bc3a60..19e17066fa 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -118,6 +118,10 @@ def _add_nested_datanames(name: str, desc: data.Structure): k: v for k, v in mapping.items() if v is not _UnknownValue and k not in multivalue_desc_symbols } + out_mapping = { + k: v + for k, v in out_consts[block].items() if v is not _UnknownValue and k not in multivalue_desc_symbols + } if mapping: # Update replaced symbols for later replacements @@ -129,9 +133,10 @@ def _add_nested_datanames(name: str, desc: data.Structure): elif isinstance(block, AbstractControlFlowRegion): block.replace_dict(mapping, replace_in_graph=False, replace_keys=False) + if out_mapping: # Replace in outgoing edges as well for e in block.parent_graph.out_edges(block): - e.data.replace_dict(mapping, replace_keys=False) + e.data.replace_dict(out_mapping, replace_keys=False) if isinstance(block, LoopRegion): if block in post_consts and post_consts[block] is not None: From 61ac6ee43a36cdd2d718310e9d4e9ad789fb7a87 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 09:32:51 +0100 Subject: [PATCH 090/108] More fixes --- dace/transformation/passes/constant_propagation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 19e17066fa..d95b72ab28 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -369,6 +369,14 @@ def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str] if reassignments and (used_in_assignments - reassignments): assignments[aname] = _UnknownValue + if isinstance(block, LoopRegion): + # Any constants before a loop that may be overwritten inside the loop cannot be assumed as constants + # for the loop itself. + assigned_in_loop = self._assignments_in_loop(block) + for k in assignments.keys(): + if k in assigned_in_loop: + assignments[k] = _UnknownValue + if block not in in_const_dict: in_const_dict[block] = {} if assignments: From 8c488de581c82db0587c498553994a7053b6a0da Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 09:56:27 +0100 Subject: [PATCH 091/108] Update gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 7209622916..03c801a68f 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,8 @@ src.VC.VC.opendb # DaCe .dacecache/ +# Ignore dacecache if added as a symlink +.dacecache out.sdfg *.out results.log From 7e4bc3da7154164ca5b1264a86dbd1cf921e6260 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 12:16:30 +0100 Subject: [PATCH 092/108] Fix loop symbol type inference and loop to map --- dace/codegen/targets/framecode.py | 17 ++++++++++------- dace/transformation/interstate/loop_to_map.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index c0e08cfba7..11a198f119 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -936,13 +936,16 @@ def generate_code(self, if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: if not cfr.loop_variable in interstate_symbols: - l_end = loop_analysis.get_loop_end(cfr) - l_start = loop_analysis.get_init_assignment(cfr) - l_step = loop_analysis.get_loop_stride(cfr) - sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), - infer_expr_type(l_step, global_symbols), - infer_expr_type(l_end, global_symbols)) - interstate_symbols[cfr.loop_variable] = sym_type + if cfr.loop_variable in global_symbols: + interstate_symbols[cfr.loop_variable] = global_symbols[cfr.loop_variable] + else: + l_end = loop_analysis.get_loop_end(cfr) + l_start = loop_analysis.get_init_assignment(cfr) + l_step = loop_analysis.get_loop_stride(cfr) + sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), + infer_expr_type(l_step, global_symbols), + infer_expr_type(l_end, global_symbols)) + interstate_symbols[cfr.loop_variable] = sym_type if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 55327af5fb..9f487f561a 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -6,7 +6,8 @@ import sympy as sp from typing import Dict, List, Set -from dace import data as dt, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import graph as gr, nodes from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil @@ -94,6 +95,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): if start is None or end is None or step is None or itervar is None: return False + sset = {} + sset.update(sdfg.symbols) + sset.update(sdfg.arrays) + t = dtypes.result_type_of(infer_expr_type(start, sset), infer_expr_type(step, sset), infer_expr_type(end, sset)) + # We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules + # except for sequential would be invalid. + if not t in dtypes.INTEGER_TYPES: + return False + # Loops containing break, continue, or returns may not be turned into a map. for blk in self.loop.all_control_flow_blocks(): if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)): From 40d4a125c941988a684be21404ffc9ebc440977a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 15:26:40 +0100 Subject: [PATCH 093/108] Fix traversal for defined symbols --- dace/sdfg/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 8015c6dd4d..46cdf1fe13 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1557,6 +1557,10 @@ def _tswds_cf_region( for _, b in region.branches: yield from _tswds_cf_region(sdfg, b, symbols, recursive) return + elif isinstance(region, LoopRegion): + # Add the own loop variable to the defined symbols, if present. + loop_syms = region.new_symbols(symbols) + symbols.update({k: v for k, v in loop_syms.items() if v is not None}) # Add symbols from inter-state edges along the state machine start_region = region.start_block From ef595b49db5b62548c9637266421c3290fedce80 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 3 Dec 2024 10:03:59 +0100 Subject: [PATCH 094/108] Skip FV3 pipeline until adapted to V2 --- .github/workflows/pyFV3-ci.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pyFV3-ci.yml b/.github/workflows/pyFV3-ci.yml index 852b887cdb..2f587e9894 100644 --- a/.github/workflows/pyFV3-ci.yml +++ b/.github/workflows/pyFV3-ci.yml @@ -1,12 +1,17 @@ name: NASA/NOAA pyFV3 repository build test +# Temporarily disabled for main, and instead applied to a specific DaCe v1 maintenance branch (v1/maintenance). Once +# the FV3 bridge has been adapted to DaCe v1, this will need to be reverted back to apply to main. on: push: - branches: [ main, ci-fix ] + #branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + #branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + #branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] defaults: run: From d27ec319f52221201183fbc5896f73d40571bcca Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 3 Dec 2024 10:26:42 +0100 Subject: [PATCH 095/108] Fix bug introduced through merge --- dace/sdfg/state.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a543b514e5..e45e1faa58 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2857,13 +2857,9 @@ def all_control_flow_regions(self, recursive=False, load_ext=False, node.load_external(block) yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext, parent_first=parent_first) - elif isinstance(block, ControlFlowRegion): + elif isinstance(block, AbstractControlFlowRegion): yield from block.all_control_flow_regions(recursive=recursive, load_ext=load_ext, parent_first=parent_first) - elif isinstance(block, ConditionalBlock): - for _, branch in block.branches: - yield from branch.all_control_flow_regions(recursive=recursive, load_ext=load_ext, - parent_first=parent_first) if not parent_first: yield self From 72508bd6d6272c5a80aab19167e85f688e81c1f0 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 9 Dec 2024 14:56:40 +0100 Subject: [PATCH 096/108] Update dace/transformation/interstate/block_fusion.py Co-authored-by: Tal Ben-Nun --- dace/transformation/interstate/block_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py index 37da066be2..813d7ad043 100644 --- a/dace/transformation/interstate/block_fusion.py +++ b/dace/transformation/interstate/block_fusion.py @@ -53,7 +53,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if out_edges[0].data.assignments: if not in_edges: return False - # If the first block is a control flow region, no absorbtion is possible. + # If the first block is a control flow region, no absorption is possible. if isinstance(self.first_block, AbstractControlFlowRegion): return False # Fail if symbol is set before the block to fuse From ce82732798237b2195997d630aa603cbef3efb24 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 9 Dec 2024 14:58:56 +0100 Subject: [PATCH 097/108] Address review comments --- dace/transformation/interstate/block_fusion.py | 7 +++---- tests/schedule_tree/schedule_test.py | 4 ---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py index 37da066be2..afc07a7709 100644 --- a/dace/transformation/interstate/block_fusion.py +++ b/dace/transformation/interstate/block_fusion.py @@ -7,11 +7,10 @@ @transformation.experimental_cfg_block_compatible class BlockFusion(transformation.MultiStateTransformation): - """ Implements the state-fusion transformation. + """ Implements the block-fusion transformation. - State-fusion takes two states that are connected through a single edge, - and fuses them into one state. If permissive, also applies if potential memory - access hazards are created. + Block-fusion takes two control flow blocks that are connected through a single edge, where either one or both + blocks are 'no-op' control flow blocks, and fuses them into one. """ first_block = transformation.PatternNode(ControlFlowBlock) diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 19f5e19cc6..295e5f6bce 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -5,10 +5,6 @@ from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree import numpy as np -from dace.sdfg.sdfg import InterstateEdge -from dace.sdfg.state import LoopRegion -from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising - def test_for_in_map_in_for(): From 3c13369292b51e1fcc3e8afbc9c86f287c1f9d12 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 9 Dec 2024 15:16:16 +0100 Subject: [PATCH 098/108] More comments --- dace/transformation/interstate/loop_unroll.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index 66dbca1a83..a41dc95dc5 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -26,7 +26,7 @@ class LoopUnroll(xf.MultiStateTransformation): ) inline_iterations = Property(dtype=bool, default=True, - desc='Whether or not to inline individual iteration\'s CFGs after unrolling') + desc="Whether or not to inline individual iterations' CFGs after unrolling") @classmethod def expressions(cls): From c4e78d7e7308b6ae023d94c5cdb2b2e6f79acf73 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 9 Dec 2024 15:24:20 +0100 Subject: [PATCH 099/108] Add doc comments --- dace/sdfg/state.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index e45e1faa58..fd0d5e99f3 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2578,6 +2578,17 @@ def sdfg(self) -> 'SDFG': @make_properties class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, ControlFlowBlock, abc.ABC): + """ + Abstract superclass to represent all kinds of control flow regions in an SDFG. + This is consequently one of the three main classes of control flow graph nodes, which include `ControlFlowBlock`s, + `SDFGState`s, and nested `AbstractControlFlowRegion`s. An `AbstractControlFlowRegion` can further be either a region + that directly contains a control flow graph (`ControlFlowRegion`s and subclasses thereof), or something that acts + like and has the same utilities as a control flow region, including the same API, but is itself not directly a + single graph. An example of this is the `ConditionalBlock`, which acts as a single control flow region to the + outside, but contains multiple actual graphs (one per branch). As such, there are very few but important differences + between the subclasses of `AbstractControlFlowRegion`s, such as how traversals are performed, how many start blocks + there are, etc. + """ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optional['AbstractControlFlowRegion'] = None): @@ -3050,6 +3061,11 @@ def start_block(self, block_id): @make_properties class ControlFlowRegion(AbstractControlFlowRegion): + """ + A `ControlFlowRegion` represents a control flow graph node that itself contains a control flow graph. + This can be an arbitrary control flow graph, but may also be a specific type of control flow region with additional + semantics, such as a loop or a function call. + """ def __init__(self, label = '', sdfg = None, parent = None): super().__init__(label, sdfg, parent) From 3a2b34253fbd563ffd564cf78f7f2024834b76c5 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 9 Dec 2024 16:42:17 +0100 Subject: [PATCH 100/108] Address more comments --- dace/sdfg/state.py | 19 ++++--- dace/transformation/interstate/loop_to_map.py | 2 +- .../passes/analysis/analysis.py | 23 ++++---- .../passes/constant_propagation.py | 52 ++++++++++--------- .../passes/dead_state_elimination.py | 12 ++--- 5 files changed, 58 insertions(+), 50 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index fd0d5e99f3..1946b19c5b 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -726,11 +726,11 @@ def update_if_not_none(dic, update): defined_syms[str(sym)] = sym.dtype # Add inter-state symbols - if isinstance(sdfg.start_block, LoopRegion): + if isinstance(sdfg.start_block, AbstractControlFlowRegion): update_if_not_none(defined_syms, sdfg.start_block.new_symbols(defined_syms)) for edge in sdfg.all_interstate_edges(): update_if_not_none(defined_syms, edge.data.new_symbols(sdfg, defined_syms)) - if isinstance(edge.dst, LoopRegion): + if isinstance(edge.dst, AbstractControlFlowRegion): update_if_not_none(defined_syms, edge.dst.new_symbols(defined_syms)) # Add scope symbols all the way to the subgraph @@ -2722,6 +2722,12 @@ def inline(self) -> Tuple[bool, Any]: return False, None + def new_symbols(self, symbols: dict) -> Dict[str, dtypes.typeclass]: + """ + Returns a mapping between the symbol defined by this control flow region and its type, if it exists. + """ + return {} + ################################################################### # CFG API methods @@ -3306,9 +3312,6 @@ def _used_symbols_internal(self, return free_syms, defined_syms, used_before_assignment def new_symbols(self, symbols) -> Dict[str, dtypes.typeclass]: - """ - Returns a mapping between the symbol defined by this loop and its type, if it exists. - """ # Avoid cyclic import from dace.codegen.tools.type_inference import infer_expr_type from dace.transformation.passes.analysis import loop_analysis @@ -3402,11 +3405,7 @@ def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): branch.sdfg = self.sdfg def remove_branch(self, branch: ControlFlowRegion): - filtered_branches = [] - for c, b in self._branches: - if b is not branch: - filtered_branches.append((c, b)) - self._branches = filtered_branches + self._branches = [(c, b) for c, b in self._branches if b is not branch] def _used_symbols_internal(self, all_symbols: bool, diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 9f487f561a..230545eecf 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -492,7 +492,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Direct edges among source and sink access nodes must pass through a tasklet. # We first gather them and handle them later. - direct_edges: Set[gr.Edge[memlet.Memlet]] = set() + direct_edges: Set[gr.MultiConnectorEdge[memlet.Memlet]] = set() for n1 in source_nodes: if not isinstance(n1, nodes.AccessNode): continue diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 2bd90a4f22..ffd8a6134f 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -240,6 +240,19 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If access nodes were modified, reapply return modified & ppl.Modifies.AccessNodes + def _get_loop_region_readset(self, loop: LoopRegion, arrays: Set[str]) -> Set[str]: + readset = set() + exprs = { loop.loop_condition.as_string } + update_stmt = loop_analysis.get_update_assignment(loop) + init_stmt = loop_analysis.get_init_assignment(loop) + if update_stmt: + exprs.add(update_stmt) + if init_stmt: + exprs.add(init_stmt) + for expr in exprs: + readset |= symbolic.free_symbols_and_functions(expr) & arrays + return readset + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]: """ :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. @@ -263,15 +276,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str] if state.out_degree(anode) > 0: readset.add(anode.data) if isinstance(block, LoopRegion): - exprs = set([ block.loop_condition.as_string ]) - update_stmt = loop_analysis.get_update_assignment(block) - init_stmt = loop_analysis.get_init_assignment(block) - if update_stmt: - exprs.add(update_stmt) - if init_stmt: - exprs.add(init_stmt) - for expr in exprs: - readset |= symbolic.free_symbols_and_functions(expr) & arrays + readset |= self._get_loop_region_readset(block, arrays) elif isinstance(block, ConditionalBlock): for cond, _ in block.branches: if cond is not None: diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index d95b72ab28..92183451a2 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -82,7 +82,6 @@ def _add_nested_datanames(name: str, desc: data.Structure): if isinstance(v, data.Structure): _add_nested_datanames(f'{name}.{k}', v) elif isinstance(v, data.ContainerArray): - # TODO: How are we handling this? pass arrays.add(f'{name}.{k}') @@ -91,11 +90,11 @@ def _add_nested_datanames(name: str, desc: data.Structure): _add_nested_datanames(name, desc) # Trace all constants and symbols through blocks - in_consts: BlockConstsT = { sdfg: initial_symbols } - pre_consts: BlockConstsT = {} - post_consts: BlockConstsT = {} - out_consts: BlockConstsT = {} - self._collect_constants_for_region(sdfg, arrays, in_consts, pre_consts, post_consts, out_consts) + in_constants: BlockConstsT = { sdfg: initial_symbols } + pre_constants: BlockConstsT = {} + post_constants: BlockConstsT = {} + out_constants: BlockConstsT = {} + self._collect_constants_for_region(sdfg, arrays, in_constants, pre_constants, post_constants, out_constants) # Keep track of replaced and ambiguous symbols symbols_replaced: Dict[str, Any] = {} @@ -103,11 +102,11 @@ def _add_nested_datanames(name: str, desc: data.Structure): # Collect symbols from symbol-dependent data descriptors # If there can be multiple values over the SDFG, the symbols are not propagated - desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_consts) + desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_constants) # Replace constants per state - for block, mapping in optional_progressbar(in_consts.items(), 'Propagating constants', n=len(in_consts), - progress=self.progress): + for block, mapping in optional_progressbar(in_constants.items(), 'Propagating constants', + n=len(in_constants), progress=self.progress): if block is sdfg: continue @@ -120,7 +119,8 @@ def _add_nested_datanames(name: str, desc: data.Structure): } out_mapping = { k: v - for k, v in out_consts[block].items() if v is not _UnknownValue and k not in multivalue_desc_symbols + for k, v in out_constants[block].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols } if mapping: @@ -139,20 +139,7 @@ def _add_nested_datanames(name: str, desc: data.Structure): e.data.replace_dict(out_mapping, replace_keys=False) if isinstance(block, LoopRegion): - if block in post_consts and post_consts[block] is not None: - if block.update_statement is not None and (block.inverted and block.update_before_condition or - not block.inverted): - # Replace the RHS of the update experssion - post_mapping = { - k: v - for k, v in post_consts[block].items() - if v is not _UnknownValue and k not in multivalue_desc_symbols - } - update_stmt = block.update_statement - updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] - for update in updates: - astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update) - block.update_statement.code = updates + self._propagate_loop(block, post_constants, multivalue_desc_symbols) # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} @@ -205,6 +192,23 @@ def _add_nested_datanames(name: str, desc: data.Structure): def report(self, pass_retval: Set[str]) -> str: return f'Propagated {len(pass_retval)} constants.' + def _propagate_loop(self, loop: LoopRegion, post_constants: BlockConstsT, + multivalue_desc_symbols: Set[str]) -> None: + if loop in post_constants and post_constants[loop] is not None: + if loop.update_statement is not None and (loop.inverted and loop.update_before_condition or + not loop.inverted): + # Replace the RHS of the update experssion + post_mapping = { + k: v + for k, v in post_constants[loop].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols + } + update_stmt = loop.update_statement + updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + for update in updates: + astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update) + loop.update_statement.code = updates + def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arrays: Set[str], in_const_dict: BlockConstsT, pre_const_dict: BlockConstsT, post_const_dict: BlockConstsT, out_const_dict: BlockConstsT) -> None: diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 23f2a785f5..e9c622d128 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -166,7 +166,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch') break # If an unconditional branch is found, ignore all other branches that follow this one. - if cond.as_string.strip() == '1' or self._is_truthy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + if cond.as_string.strip() == '1' or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): unconditional = branch break if unconditional is not None: @@ -177,7 +177,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, else: # Check if any branches are certainly never taken. for cond, branch in block.branches: - if cond is not None and self._is_falsy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + if cond is not None and self._is_definitely_false(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): dead_branches.append([cond, branch]) return dead_branches @@ -195,9 +195,9 @@ def is_definitely_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return True # Evaluate condition - return self._is_truthy(edge.condition_sympy(), sdfg) + return self._is_definitely_true(edge.condition_sympy(), sdfg) - def _is_truthy(self, cond: sp.Basic, sdfg: SDFG) -> bool: + def _is_definitely_true(self, cond: sp.Basic, sdfg: SDFG) -> bool: if cond == True or cond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False): return True @@ -215,9 +215,9 @@ def is_definitely_not_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return False # Evaluate condition - return self._is_falsy(edge.condition_sympy(), sdfg) + return self._is_definitely_false(edge.condition_sympy(), sdfg) - def _is_falsy(self, cond: sp.Basic, sdfg: SDFG) -> bool: + def _is_definitely_false(self, cond: sp.Basic, sdfg: SDFG) -> bool: if cond == False or cond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False): return True From 2f7d6aa30004e4d065bc157ab12efe53aa6fa185 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 10 Dec 2024 11:41:50 +0100 Subject: [PATCH 101/108] Address more review comments --- dace/sdfg/state.py | 6 ++ dace/sdfg/utils.py | 16 ++--- dace/sdfg/validation.py | 6 +- dace/transformation/helpers.py | 2 +- .../interstate/fpga_transform_state.py | 7 +-- .../interstate/multistate_inline.py | 7 +-- dace/transformation/passes/fusion_inline.py | 16 ++--- .../prune_empty_conditional_branches.py | 3 + dace/transformation/subgraph/composite.py | 2 +- .../writeset_underapproximation_test.py | 58 +++---------------- tests/sdfg/work_depth_test.py | 17 +----- 11 files changed, 47 insertions(+), 93 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1946b19c5b..61c4a1f727 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1225,6 +1225,9 @@ def nodes(self): def edges(self): return [] + def sub_regions(self) -> List['AbstractControlFlowRegion']: + return [] + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): """ Sets the default source line information to be lineinfo, or None to @@ -3389,6 +3392,9 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optio super().__init__(label, sdfg, parent) self._branches = [] + def sub_regions(self): + return [b for _, b in self.branches] + def __str__(self): return self._label diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 46cdf1fe13..7cd438be88 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1305,13 +1305,13 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractControlFlowRegion]]] = None, - blacklist: Optional[List[Type[AbstractControlFlowRegion]]] = None, + ignore_region_types: Optional[List[Type[AbstractControlFlowRegion]]] = None, progress: bool = None) -> int: if types: blocks = [n for n, _ in sdfg.all_nodes_recursive() if type(n) in types] - elif blacklist: + elif ignore_region_types: blocks = [n for n, _ in sdfg.all_nodes_recursive() - if isinstance(n, AbstractControlFlowRegion) and type(n) not in blacklist] + if isinstance(n, AbstractControlFlowRegion) and type(n) not in ignore_region_types] else: blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, AbstractControlFlowRegion)] count = 0 @@ -1327,6 +1327,8 @@ def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractCo if block.inline()[0]: count += 1 + sdfg.reset_cfg_list() + return count @@ -1557,10 +1559,10 @@ def _tswds_cf_region( for _, b in region.branches: yield from _tswds_cf_region(sdfg, b, symbols, recursive) return - elif isinstance(region, LoopRegion): - # Add the own loop variable to the defined symbols, if present. - loop_syms = region.new_symbols(symbols) - symbols.update({k: v for k, v in loop_syms.items() if v is not None}) + + # Add the own loop variable to the defined symbols, if present. + loop_syms = region.new_symbols(symbols) + symbols.update({k: v for k, v in loop_syms.items() if v is not None}) # Add symbols from inter-state edges along the state machine start_region = region.start_block diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 603c9cb314..b030d85466 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -81,7 +81,7 @@ def validate_control_flow_region(sdfg: 'SDFG', if r is not None: validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.src, ControlFlowRegion): - lsyms = copy.deepcopy(symbols) + lsyms = copy.copy(symbols) if isinstance(edge.src, LoopRegion) and not edge.src.loop_variable in lsyms: lsyms[edge.src.loop_variable] = None validate_control_flow_region(sdfg, edge.src, initialized_transients, lsyms, references, **context) @@ -147,7 +147,7 @@ def validate_control_flow_region(sdfg: 'SDFG', if r is not None: validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.dst, ControlFlowRegion): - lsyms = copy.deepcopy(symbols) + lsyms = copy.copy(symbols) if isinstance(edge.dst, LoopRegion) and not edge.dst.loop_variable in lsyms: lsyms[edge.dst.loop_variable] = None validate_control_flow_region(sdfg, edge.dst, initialized_transients, lsyms, references, **context) @@ -163,7 +163,7 @@ def validate_control_flow_region(sdfg: 'SDFG', if r is not None: validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(start_block, ControlFlowRegion): - lsyms = copy.deepcopy(symbols) + lsyms = copy.copy(symbols) if isinstance(start_block, LoopRegion) and not start_block.loop_variable in lsyms: lsyms[start_block.loop_variable] = None validate_control_flow_region(sdfg, start_block, initialized_transients, lsyms, references, **context) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 5aa7c7a347..b703dd402d 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -145,7 +145,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args` pass for b in all_blocks: - if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '': + if isinstance(b, LoopRegion) and b.loop_variable: defined_symbols.add(b.loop_variable) if b.loop_variable not in sdfg.symbols: if b.init_statement: diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index 287b725abb..602cca6df1 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -98,13 +98,12 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): # Input nodes may also be nodes with WCR memlets # We have to recur across nested SDFGs to find them wcr_input_nodes = set() - stack = [] - for node, pGraph in state.all_nodes_recursive(): + for node, node_parent_graph in state.all_nodes_recursive(): if isinstance(node, dace.sdfg.nodes.AccessNode): - for e in pGraph.in_edges(node): + for e in node_parent_graph.in_edges(node): if e.data.wcr is not None: - trace = dace.sdfg.trace_nested_access(node, pGraph, pGraph.sdfg) + trace = dace.sdfg.trace_nested_access(node, node_parent_graph, node_parent_graph.sdfg) for node_trace, memlet_trace, state_trace, sdfg_trace in trace: # Find the name of the accessed node in our scope if state_trace == state and sdfg_trace == sdfg: diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 80afdf2f36..45de07a5a6 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -145,12 +145,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): from dace.transformation.passes.fusion_inline import InlineControlFlowRegions from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising - inline_pass = InlineControlFlowRegions() - inline_pass.no_inline_conditional = False - inline_pass.no_inline_named_regions = False - inline_pass.no_inline_function_call_regions = False - inline_pass.no_inline_loops = False - inline_pass.apply_pass(nsdfg, {}) + sdutil.inline_control_flow_regions(nsdfg) # After inlining, try to lift out control flow again, essentially preserving all control flow that can be # preserved while removing the return blocks. ControlFlowRaising().apply_pass(nsdfg, {}) diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 277403f2e7..13583998ac 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -126,21 +126,21 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: :return: The total number of states fused, or None if did not apply. """ - blacklist = [] + ignore_region_types = [] if self.no_inline_loops: - blacklist.append(LoopRegion) + ignore_region_types.append(LoopRegion) if self.no_inline_conditional: - blacklist.append(ConditionalBlock) + ignore_region_types.append(ConditionalBlock) if self.no_inline_named_regions: - blacklist.append(NamedRegion) + ignore_region_types.append(NamedRegion) if self.no_inline_function_call_regions: - blacklist.append(FunctionCallRegion) - if len(blacklist) < 1: - blacklist = None + ignore_region_types.append(FunctionCallRegion) + if len(ignore_region_types) < 1: + ignore_region_types = None inlined = 0 while True: - inlined_in_iteration = inline_control_flow_regions(sdfg, None, blacklist, self.progress) + inlined_in_iteration = inline_control_flow_regions(sdfg, None, ignore_region_types, self.progress) if inlined_in_iteration < 1: break inlined += inlined_in_iteration diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py index 29a400d5d1..1f44008351 100644 --- a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -11,6 +11,9 @@ @properties.make_properties @transformation.experimental_cfg_block_compatible class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass): + """ + Prunes empty (or no-op) conditional branches from conditional blocks. + """ CATEGORY: str = 'Simplification' diff --git a/dace/transformation/subgraph/composite.py b/dace/transformation/subgraph/composite.py index c9f26f8670..a7904e5108 100644 --- a/dace/transformation/subgraph/composite.py +++ b/dace/transformation/subgraph/composite.py @@ -71,7 +71,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: break if not par_graph_copy: return False - graph_copy = par_graph_copy.nodes()[graph.block_id] + graph_copy = par_graph_copy.node(graph.block_id) subgraph_copy = SubgraphView(graph_copy, [graph_copy.nodes()[i] for i in graph_indices]) expansion.cfg_id = par_graph_copy.cfg_id diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 0e6f9d4fb4..b92aee12da 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -3,9 +3,9 @@ from typing import Dict import dace from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDict +from dace.sdfg.utils import inline_control_flow_regions from dace.subsets import Range from dace.transformation.pass_pipeline import Pipeline -from dace.transformation.passes.fusion_inline import InlineControlFlowRegions N = dace.symbol("N") M = dace.symbol("M") @@ -309,12 +309,7 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -341,12 +336,7 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -474,12 +464,7 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -517,12 +502,7 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -558,12 +538,7 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -600,12 +575,7 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) result: Dict[int, UnderapproximateWritesDict] = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -874,12 +844,7 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -912,12 +877,7 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 159e2eb8b6..7c86f454c6 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -13,13 +13,12 @@ import sympy as sp import numpy as np +from dace.sdfg.utils import inline_control_flow_regions from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion from pytest import raises -from dace.transformation.passes.fusion_inline import InlineControlFlowRegions - N = dc.symbol('N') M = dc.symbol('M') K = dc.symbol('K') @@ -221,12 +220,7 @@ def test_work_depth(test_name): sdfg.apply_transformations(MapExpansion) # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) for sd in sdfg.all_sdfgs_recursive(): sd.using_experimental_blocks = False @@ -279,12 +273,7 @@ def test_avg_par(test_name: str): sdfg.apply_transformations(MapExpansion) # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg) for sd in sdfg.all_sdfgs_recursive(): sd.using_experimental_blocks = False From a4692869f41919fb93eee872a685b87b3849a3a7 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 10 Dec 2024 12:59:20 +0100 Subject: [PATCH 102/108] Inlining fix --- dace/sdfg/state.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 61c4a1f727..40aa866d03 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2706,15 +2706,6 @@ def inline(self) -> Tuple[bool, Any]: for node in to_connect: parent.add_edge(node, end_state, dace.InterstateEdge()) - else: - # TODO: Move this to dead state elimination. - dead_blocks = [succ for succ in parent.successors(self) if parent.in_degree(succ) == 1] - while dead_blocks: - layer = list(dead_blocks) - dead_blocks.clear() - for u in layer: - dead_blocks.extend([succ for succ in parent.successors(u) if parent.in_degree(succ) == 1]) - parent.remove_node(u) # Remove the original control flow region (self) from the parent graph. parent.remove_node(self) @@ -3515,6 +3506,7 @@ def inline(self) -> Tuple[bool, Any]: parent.add_node(region) parent.add_edge(guard_state, region, InterstateEdge(condition=condition)) parent.add_edge(region, end_state, InterstateEdge()) + region.inline() if full_cond_expression is not None: negative_full_cond = astutils.negate_expr(full_cond_expression) negative_cond = CodeBlock([negative_full_cond]) @@ -3524,7 +3516,8 @@ def inline(self) -> Tuple[bool, Any]: if else_branch is not None: parent.add_node(else_branch) parent.add_edge(guard_state, else_branch, InterstateEdge(condition=negative_cond)) - parent.add_edge(region, end_state, InterstateEdge()) + parent.add_edge(else_branch, end_state, InterstateEdge()) + else_branch.inline() else: parent.add_edge(guard_state, end_state, InterstateEdge(condition=negative_cond)) From a8546ab66b05d887a3403f9a8f3f18a5e6540c60 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 11 Dec 2024 13:45:24 +0100 Subject: [PATCH 103/108] Fixes to control flow raising and codegen --- dace/codegen/control_flow.py | 42 +++-- dace/codegen/targets/framecode.py | 14 +- dace/sdfg/infer_types.py | 1 - dace/sdfg/sdfg.py | 6 +- dace/sdfg/state.py | 176 +++++++++++------- dace/sdfg/utils.py | 9 +- .../interstate/gpu_transform_sdfg.py | 15 +- .../simplification/control_flow_raising.py | 119 ++++++++++-- .../control_flow_raising_test.py | 44 ++--- 9 files changed, 276 insertions(+), 150 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index fdba40526d..31988ba700 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various classes to facilitate the code generation of structured control flow elements (e.g., ``for``, ``if``, ``while``) from state machines in SDFGs. @@ -200,7 +200,10 @@ class BreakCFBlock(ControlFlow): block: BreakBlock def as_cpp(self, codegen, symbols) -> str: - return 'break;\n' + cfg = self.block.parent_graph + expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label) + expr += 'break;\n' + return expr @property def first_block(self) -> BreakBlock: @@ -214,7 +217,10 @@ class ContinueCFBlock(ControlFlow): block: ContinueBlock def as_cpp(self, codegen, symbols) -> str: - return 'continue;\n' + cfg = self.block.parent_graph + expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label) + expr += 'continue;\n' + return expr @property def first_block(self) -> ContinueBlock: @@ -228,7 +234,10 @@ class ReturnCFBlock(ControlFlow): block: ReturnBlock def as_cpp(self, codegen, symbols) -> str: - return 'return;\n' + cfg = self.block.parent_graph + expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label) + expr += 'return;\n' + return expr @property def first_block(self) -> ReturnBlock: @@ -316,7 +325,13 @@ def as_cpp(self, codegen, symbols) -> str: # One unconditional edge if (len(out_edges) == 1 and out_edges[0].data.is_unconditional()): continue - expr += f'goto __state_exit_{sdfg.cfg_id};\n' + if self.region: + expr += f'goto __state_exit_{self.region.cfg_id};\n' + else: + expr += f'goto __state_exit_{sdfg.cfg_id};\n' + + if self.region and not isinstance(self.region, SDFG): + expr += f'__state_exit_{self.region.cfg_id}:;\n' return expr @@ -575,6 +590,8 @@ def as_cpp(self, codegen, symbols) -> str: expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) expr += '\n}\n' + expr += f'__state_exit_{self.loop.cfg_id}:;\n' + return expr @property @@ -1022,21 +1039,16 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion, start: Optional[ControlFlowBlock] = None, stop: Optional[ControlFlowBlock] = None, generate_children_of: Optional[ControlFlowBlock] = None, - branch_merges: Optional[Dict[ControlFlowBlock, - ControlFlowBlock]] = None, ptree: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None, visited: Optional[Set[ControlFlowBlock]] = None): - if branch_merges is None: - branch_merges = cfg_analysis.branch_merges(cfg) - if ptree is None: ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False) start = start if start is not None else cfg.start_block - def make_empty_block(): + def make_empty_block(region): return GeneralBlock(dispatch_state, parent_block, - last_block=False, region=None, elements=[], gotos_to_ignore=[], + last_block=False, region=region, elements=[], gotos_to_ignore=[], gotos_to_break=[], gotos_to_continue=[], assignments_to_ignore=[], sequential=True) # Traverse states in custom order @@ -1063,18 +1075,18 @@ def make_empty_block(): cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, []) for cond, branch in node.branches: if branch is not None: - body = make_empty_block() + body = make_empty_block(branch) body.parent = cfg_block _structured_control_flow_traversal_with_regions(branch, dispatch_state, body) cfg_block.branch_bodies.append((cond, body)) elif isinstance(node, ControlFlowRegion): if isinstance(node, LoopRegion): - body = make_empty_block() + body = make_empty_block(node) cfg_block = GeneralLoopScope(dispatch_state, parent_block, False, node, body) body.parent = cfg_block _structured_control_flow_traversal_with_regions(node, dispatch_state, body) else: - cfg_block = make_empty_block() + cfg_block = make_empty_block(node) cfg_block.region = node _structured_control_flow_traversal_with_regions(node, dispatch_state, cfg_block) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 11a198f119..cd98c26479 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -697,19 +697,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): if name in self.free_symbols(isedge.data): multistate = True for cfg in sdfg.all_control_flow_regions(): - block_syms = set() - if isinstance(cfg, LoopRegion): - block_syms |= symbolic.free_symbols_and_functions(cfg.loop_condition.as_string) - update_stmt = loop_analysis.get_update_assignment(cfg) - init_stmt = loop_analysis.get_init_assignment(cfg) - if update_stmt: - block_syms |= symbolic.free_symbols_and_functions(update_stmt) - if init_stmt: - block_syms |= symbolic.free_symbols_and_functions(init_stmt) - elif isinstance(cfg, ConditionalBlock): - for cond, _ in cfg.branches: - if cond is not None: - block_syms |= symbolic.free_symbols_and_functions(cond.as_string) + block_syms = cfg.used_symbols(all_symbols=True, with_contents=False) if name in block_syms: multistate = True diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index c05708670e..940114bbe2 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -1,7 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from dace import data, dtypes -from dace.codegen.tools import type_inference from dace.memlet import Memlet from dace.sdfg import SDFG, SDFGState, nodes, validation from dace.sdfg import nodes diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 992d593ddb..c6a7741628 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1306,7 +1306,8 @@ def _used_symbols_internal(self, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -1327,7 +1328,8 @@ def _used_symbols_internal(self, keep_defined_in_mapping=keep_defined_in_mapping, defined_syms=defined_syms, free_syms=free_syms, - used_before_assignment=used_before_assignment) + used_before_assignment=used_before_assignment, + with_contents=with_contents) def get_all_toplevel_symbols(self) -> Set[str]: """ diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 40aa866d03..5e5d07b288 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -215,13 +215,18 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi # Query, subgraph, and replacement methods @abc.abstractmethod - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Set[str]: """ Returns a set of symbol names that are used in the graph. :param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code). :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping will be removed from the set of defined symbols. + :param with_contents: Compute the symbols used including the ones used by the contents of the graph. If set to + False, only symbols used on the BlockGraphView itself are returned. The latter may + include symbols used in the conditions of conditional blocks, loops, etc. Defaults to + True. """ return set() @@ -645,7 +650,11 @@ def is_leaf_memlet(self, e): return False return True - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Set[str]: + if not with_contents: + return set() + state = self.graph if isinstance(self, SubgraphView) else self sdfg = state.sdfg new_symbols = set() @@ -693,17 +702,6 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) new_symbols.update(set(sdfg.constants.keys())) return freesyms - new_symbols - @property - def free_symbols(self) -> Set[str]: - """ - Returns a set of symbol names that are used, but not defined, in - this graph view (SDFG state or subgraph thereof). - - :note: Assumes that the graph is valid (i.e., without undefined or - overlapping symbols). - """ - return self.used_symbols(all_symbols=True) - def defined_symbols(self) -> Dict[str, dt.Data]: """ Returns a dictionary that maps currently-defined symbols in this SDFG @@ -1117,11 +1115,14 @@ def _used_symbols_internal(self, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: raise NotImplementedError() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: - return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0] + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Set[str]: + return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, + with_contents=with_contents)[0] def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set = set() @@ -2604,6 +2605,13 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, self._cached_start_block: Optional[ControlFlowBlock] = None self._cfg_list: List['ControlFlowRegion'] = [self] + def get_meta_read_memlets(self) -> List[mm.Memlet]: + """ + Get read memlets used by the control flow region itself, such as in condition checks for conditional blocks, or + in loop conditions for loops etc. + """ + return [] + @property def root_sdfg(self) -> 'SDFG': from dace.sdfg.sdfg import SDFG # Avoid import loop @@ -2661,29 +2669,37 @@ def state(self, state_id: int) -> SDFGState: raise TypeError(f'The node with id {state_id} is not an SDFGState') return node - def inline(self) -> Tuple[bool, Any]: + def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ Inlines the control flow region into its parent control flow region (if it exists). + :param lower_returns: Whether or not to remove explicit return blocks when inlining where possible. Defaults to + False. :return: True if the inlining succeeded, false otherwise. """ parent = self.parent_graph if parent: # Add all region states and make sure to keep track of all the ones that need to be connected in the end. - to_connect: Set[SDFGState] = set() + to_connect: Set[ControlFlowBlock] = set() + ends_context: Set[ControlFlowBlock] = set() block_to_state_map: Dict[ControlFlowBlock, SDFGState] = dict() for node in self.nodes(): node.label = self.label + '_' + node.label - if isinstance(node, ReturnBlock) and isinstance(parent, dace.SDFG): + if isinstance(node, ReturnBlock) and lower_returns and isinstance(parent, dace.SDFG): # If a return block is being inlined into an SDFG, convert it into a regular state. Otherwise it # remains as-is. newnode = parent.add_state(node.label) block_to_state_map[node] = newnode + if self.out_degree(node) == 0: + to_connect.add(newnode) + ends_context.add(newnode) else: parent.add_node(node, ensure_unique_name=True) - if self.out_degree(node) == 0 and not isinstance(node, (BreakBlock, ContinueBlock, ReturnBlock)): + if self.out_degree(node) == 0: to_connect.add(node) + if isinstance(node, (BreakBlock, ContinueBlock, ReturnBlock)): + ends_context.add(node) # Add all region edges. for edge in self.edges(): @@ -2705,7 +2721,10 @@ def inline(self) -> Tuple[bool, Any]: parent.remove_edge(a_edge) for node in to_connect: - parent.add_edge(node, end_state, dace.InterstateEdge()) + if node in ends_context: + parent.add_edge(node, end_state, dace.InterstateEdge(condition='False')) + else: + parent.add_edge(node, end_state, dace.InterstateEdge()) # Remove the original control flow region (self) from the parent graph. parent.remove_node(self) @@ -2908,45 +2927,45 @@ def _used_symbols_internal(self, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment - try: - ordered_blocks = self.bfs_nodes(self.start_block) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_blocks = self.nodes() - - for block in ordered_blocks: - state_symbols = set() - if isinstance(block, (ControlFlowRegion, ConditionalBlock)): - b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols, - defined_syms, - free_syms, - used_before_assignment, - keep_defined_in_mapping) - free_syms |= b_free_syms - defined_syms |= b_defined_syms - used_before_assignment |= b_used_before_syms - state_symbols = b_free_syms - else: - state_symbols = block.used_symbols(all_symbols, keep_defined_in_mapping) - free_syms |= state_symbols - - # Add free inter-state symbols - for e in self.out_edges(block): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. - efsyms = e.data.used_symbols(all_symbols) - # collect symbols representing data containers - dsyms = {sym for sym in efsyms if sym in self.sdfg.arrays} - for d in dsyms: - efsyms |= {str(sym) for sym in self.sdfg.arrays[d].used_symbols(all_symbols)} - defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms + if with_contents: + try: + ordered_blocks = self.bfs_nodes(self.start_block) + except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) + ordered_blocks = self.nodes() + + for block in ordered_blocks: + state_symbols = set() + if isinstance(block, (ControlFlowRegion, ConditionalBlock)): + b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal( + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping, + with_contents) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms + state_symbols = b_free_syms + else: + state_symbols = block.used_symbols(all_symbols, keep_defined_in_mapping, with_contents) + free_syms |= state_symbols + + # Add free inter-state symbols + for e in self.out_edges(block): + # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols + # by subracting the (true) free symbols from the edge's assignment keys. This way we can correctly + # compute the symbols that are used before being assigned. + efsyms = e.data.used_symbols(all_symbols) + # collect symbols representing data containers + dsyms = {sym for sym in efsyms if sym in self.sdfg.arrays} + for d in dsyms: + efsyms |= {str(sym) for sym in self.sdfg.arrays[d].used_symbols(all_symbols)} + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) + used_before_assignment.update(efsyms - defined_syms) + free_syms |= efsyms # Remove symbols that were used before they were assigned. defined_syms -= used_before_assignment @@ -3152,10 +3171,12 @@ def __init__(self, self.inverted = inverted self.update_before_condition = update_before_condition - def inline(self) -> Tuple[bool, Any]: + def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ Inlines the loop region into its parent control flow region. + :param lower_returns: Whether or not to remove explicit return blocks when inlining where possible. Defaults to + False. :return: True if the inlining succeeded, false otherwise. """ parent = self.parent_graph @@ -3184,7 +3205,7 @@ def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: if ((isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalBlock)) and not isinstance(block, LoopRegion)): recursive_inline_cf_regions(block) - block.inline() + block.inline(lower_returns=lower_returns) recursive_inline_cf_regions(self) # Add all boilerplate loop states necessary for the structure. @@ -3273,12 +3294,23 @@ def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: return True, (init_state, guard_state, end_state) + def get_meta_read_memlets(self) -> List[mm.Memlet]: + # Avoid cyclic imports. + from dace.sdfg.sdfg import memlets_in_ast + read_memlets = memlets_in_ast(self.loop_condition.code[0], self.sdfg.arrays) + if self.init_statement: + read_memlets.extend(memlets_in_ast(self.init_statement.code[0], self.sdfg.arrays)) + if self.update_statement: + read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays)) + return read_memlets + def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -3293,7 +3325,7 @@ def _used_symbols_internal(self, cond_free_syms.remove(self.loop_variable) b_free_symbols, b_defined_symbols, b_used_before_assignment = super()._used_symbols_internal( - all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) + all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, with_contents=with_contents) outside_defined = defined_syms - used_before_assignment used_before_assignment |= ((b_used_before_assignment - {self.loop_variable}) - outside_defined) free_syms |= b_free_symbols @@ -3403,13 +3435,23 @@ def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): def remove_branch(self, branch: ControlFlowRegion): self._branches = [(c, b) for c, b in self._branches if b is not branch] + + def get_meta_read_memlets(self) -> List[mm.Memlet]: + # Avoid cyclic imports. + from dace.sdfg.sdfg import memlets_in_ast + read_memlets = [] + for c, _ in self.branches: + if c is not None: + read_memlets.extend(memlets_in_ast(c.code[0], self.sdfg.arrays)) + return read_memlets def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -3418,7 +3460,7 @@ def _used_symbols_internal(self, if condition is not None: free_syms |= condition.get_free_symbols(defined_syms) b_free_symbols, b_defined_symbols, b_used_before_assignment = region._used_symbols_internal( - all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping, with_contents) free_syms |= b_free_symbols defined_syms |= b_defined_symbols used_before_assignment |= b_used_before_assignment @@ -3468,11 +3510,13 @@ def from_json(cls, json_obj, context=None): else: ret.add_branch(None, ControlFlowRegion.from_json(region, context)) return ret - - def inline(self) -> Tuple[bool, Any]: + + def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ Inlines the conditional region into its parent control flow region. + :param lower_returns: Whether or not to remove explicit return blocks when inlining where possible. Defaults to + False. :return: True if the inlining succeeded, false otherwise. """ parent = self.parent_graph @@ -3506,7 +3550,7 @@ def inline(self) -> Tuple[bool, Any]: parent.add_node(region) parent.add_edge(guard_state, region, InterstateEdge(condition=condition)) parent.add_edge(region, end_state, InterstateEdge()) - region.inline() + region.inline(lower_returns=lower_returns) if full_cond_expression is not None: negative_full_cond = astutils.negate_expr(full_cond_expression) negative_cond = CodeBlock([negative_full_cond]) @@ -3517,7 +3561,7 @@ def inline(self) -> Tuple[bool, Any]: parent.add_node(else_branch) parent.add_edge(guard_state, else_branch, InterstateEdge(condition=negative_cond)) parent.add_edge(else_branch, end_state, InterstateEdge()) - else_branch.inline() + else_branch.inline(lower_returns=lower_returns) else: parent.add_edge(guard_state, end_state, InterstateEdge(condition=negative_cond)) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 7cd438be88..dddbfb7652 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1306,7 +1306,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractControlFlowRegion]]] = None, ignore_region_types: Optional[List[Type[AbstractControlFlowRegion]]] = None, - progress: bool = None) -> int: + progress: bool = None, lower_returns: bool = False, + eliminate_dead_states: bool = False) -> int: if types: blocks = [n for n, _ in sdfg.all_nodes_recursive() if type(n) in types] elif ignore_region_types: @@ -1324,8 +1325,12 @@ def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractCo # Control flow regions where the parent is a conditional block are not inlined. if block.parent_graph and type(block.parent_graph) == ConditionalBlock: continue - if block.inline()[0]: + if block.inline(lower_returns=lower_returns)[0]: count += 1 + if eliminate_dead_states: + # Avoid cyclic imports. + from dace.transformation.passes.dead_state_elimination import DeadStateElimination + DeadStateElimination().apply_pass(sdfg, {}) sdfg.reset_cfg_list() diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 3a82e05976..5365d2ce35 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -5,8 +5,7 @@ from dace.sdfg import nodes, scope from dace.sdfg import utils as sdutil from dace.sdfg.replace import replace_in_codeblock -from dace.sdfg.sdfg import memlets_in_ast -from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, LoopRegion, SDFGState from dace.transformation import transformation, helpers as xfh from dace.properties import ListProperty, Property, make_properties from collections import defaultdict @@ -290,16 +289,8 @@ def apply(self, _, sdfg: sd.SDFG): for edge in sdfg.all_interstate_edges(): check_memlets.extend(edge.data.get_read_memlets(sdfg.arrays)) for blk in sdfg.all_control_flow_blocks(): - if isinstance(blk, ConditionalBlock): - for c, _ in blk.branches: - if c is not None: - check_memlets.extend(memlets_in_ast(c.code[0], sdfg.arrays)) - elif isinstance(blk, LoopRegion): - check_memlets.extend(memlets_in_ast(blk.loop_condition.code[0], sdfg.arrays)) - if blk.init_statement: - check_memlets.extend(memlets_in_ast(blk.init_statement.code[0], sdfg.arrays)) - if blk.update_statement: - check_memlets.extend(memlets_in_ast(blk.update_statement.code[0], sdfg.arrays)) + if isinstance(blk, AbstractControlFlowRegion): + check_memlets.extend(blk.get_meta_read_memlets()) for mem in check_memlets: if sdfg.arrays[mem.data].storage == dtypes.StorageType.GPU_Global: data_already_on_gpu[mem.data] = None diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index b89a09b196..6ef900f28f 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -1,13 +1,19 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Optional, Tuple +import ast +from typing import List, Optional, Tuple + import networkx as nx +import sympy + from dace import properties +from dace.frontend.python import astutils from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import ConditionalBlock, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, ReturnBlock from dace.sdfg.utils import dfs_conditional -from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation import pass_pipeline as ppl +from dace.transformation import transformation from dace.transformation.interstate.loop_lifting import LoopLifting @@ -20,12 +26,84 @@ class ControlFlowRaising(ppl.Pass): CATEGORY: str = 'Simplification' + raise_sink_node_returns = properties.Property( + dtype=bool, + default=False, + desc='Whether or not to lift sink nodes in an SDFG context to explicit return blocks.') + def modifies(self) -> ppl.Modifies: return ppl.Modifies.CFG def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG + def _lift_returns(self, sdfg: SDFG) -> int: + """ + Make any implicit early program exits explicit by inserting return blocks. + An implicit early program exit is a control flow block with not at least one unconditional edge leading out of + it, or where there is no 'catchall' condition that negates all other conditions. For any such transition, if + the condition(s) is / are not met, the SDFG halts. + This method detects such situations and inserts an explicit transition to a return block for each such missing + unconditional edge or 'catchall' condition. Note that this is only performed on the top-level control flow + region, i.e., the SDFG itself. Any implicit early stops inside nested regions only end the context of that + region, and not the entire SDFG. + + :param sdfg: The SDFG in which to lift returns + :returns: The number of return blocks lifted + """ + returns_lifted = 0 + for nd in sdfg.nodes(): + # Existing returns can be skipped. + if isinstance(nd, ReturnBlock): + continue + + # First check if there is an unconditional outgoing edge. + has_unconditional = False + full_cond_expression: Optional[List[ast.AST]] = None + oedges = sdfg.out_edges(nd) + for oe in oedges: + if oe.data.is_unconditional(): + has_unconditional = True + break + else: + if full_cond_expression is None: + full_cond_expression = oe.data.condition.code[0] + else: + full_cond_expression = astutils.and_expr(full_cond_expression, oe.data.condition.code[0]) + # If there is no unconditional outgoing edge, there may be a catchall that is the negation of all other + # conditions. + # NOTE: Checking that for the general case is expensive. For now, we check it for the case of two outgoing + # edges, where the two edges are a negation of one another, which is cheap. In any other case, an + # explicit return is added with the negation of everything. This is conservative and always correct, + # but may insert a stray (and unreachable) return in rare cases. That case should hardly ever occur + # and does not lead to any negative side effects. + if has_unconditional: + insert_return = False + else: + if len(oedges) == 2 and oedges[0].data.condition_sympy() == sympy.Not(oedges[1].data.condition_sympy()): + insert_return = False + else: + insert_return = True + + if insert_return: + if full_cond_expression is None: + # If there is no condition, there are no outgoing edges - so this is already an explicit program + # exit by being a sink node. + if self.raise_sink_node_returns: + ret_block = ReturnBlock(sdfg.name + '_return') + sdfg.add_node(ret_block, ensure_unique_name=True) + sdfg.add_edge(nd, ret_block, InterstateEdge()) + returns_lifted += 1 + else: + ret_block = ReturnBlock(nd.label + '_return') + sdfg.add_node(ret_block, ensure_unique_name=True) + catchall_condition_expression = astutils.negate_expr(full_cond_expression) + ret_edge = InterstateEdge(condition=properties.CodeBlock([catchall_condition_expression])) + sdfg.add_edge(nd, ret_block, ret_edge) + returns_lifted += 1 + + return returns_lifted + def _lift_conditionals(self, sdfg: SDFG) -> int: cfgs = list(sdfg.all_control_flow_regions()) n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) @@ -34,10 +112,16 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: if isinstance(region, ConditionalBlock): continue - sinks = region.sink_nodes() - dummy_exit = region.add_state('__DACE_DUMMY') - for s in sinks: - region.add_edge(s, dummy_exit, InterstateEdge()) + # If there are multiple sinks, create a dummy exit node for finding branch merges. If there is at least one + # non-return block sink, do not count return blocks as sink nodes. Doing so could cause branches to inter- + # connect unnecessarily, thus preventing lifting. + non_return_sinks = [s for s in region.sink_nodes() if not isinstance(s, ReturnBlock)] + sinks = non_return_sinks if len(non_return_sinks) > 0 else region.sink_nodes() + dummy_exit = None + if len(sinks) > 1: + dummy_exit = region.add_state('__DACE_DUMMY') + for s in sinks: + region.add_edge(s, dummy_exit, InterstateEdge()) idom = nx.immediate_dominators(region.nx, region.start_block) alldoms = cfg_analysis.all_dominators(region, idom) branch_merges = cfg_analysis.branch_merges(region, idom, alldoms) @@ -68,7 +152,6 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) branch_start = branch.add_state(branch_name + '_start', is_start_block=True) branch.add_nodes_from(branch_nodes) - branch_end = branch.add_state(branch_name + '_end') branch.add_edge(branch_start, oe.dst, InterstateEdge(assignments=oe.data.assignments)) added = set() for e in graph.all_edges(*branch_nodes): @@ -77,15 +160,17 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: if e is oe: continue elif e.dst is merge_block: - branch.add_edge(e.src, branch_end, e.data) + if e.data.assignments or not e.data.is_unconditional(): + branch.add_edge(e.src, branch.add_state(branch_name + '_end'), e.data) else: branch.add_edge(e.src, e.dst, e.data) graph.remove_nodes_from(branch_nodes) # Connect to the end of the branch / what happens after. - if merge_block is not dummy_exit: + if dummy_exit is None or merge_block is not dummy_exit: graph.add_edge(conditional, merge_block, InterstateEdge()) - region.remove_node(dummy_exit) + if dummy_exit is not None: + region.remove_node(dummy_exit) n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) lifted = n_cond_regions_post - n_cond_regions_pre @@ -93,13 +178,21 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: sdfg.root_sdfg.using_experimental_blocks = True return lifted - def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int, int]]: + lifted_returns = 0 lifted_loops = 0 lifted_branches = 0 for sdfg in top_sdfg.all_sdfgs_recursive(): + lifted_returns += self._lift_returns(sdfg) lifted_loops += sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) lifted_branches += self._lift_conditionals(sdfg) if lifted_branches == 0 and lifted_loops == 0: return None top_sdfg.reset_cfg_list() - return lifted_loops, lifted_branches + return lifted_returns, lifted_loops, lifted_branches + + def report(self, pass_retval: Optional[Tuple[int, int, int]]): + if pass_retval and any([x > 0 for x in pass_retval]): + return f'Lifted {pass_retval[0]} returns, {pass_retval[1]} loops, and {pass_retval[2]} conditional blocks' + else: + return 'No control flow lifted' diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py index 22701fb2bb..8b22446974 100644 --- a/tests/passes/simplification/control_flow_raising_test.py +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -1,14 +1,16 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import pytest import dace import numpy as np from dace.sdfg.state import ConditionalBlock -from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline -from dace.transformation.passes.fusion_inline import InlineControlFlowRegions +from dace.sdfg.utils import inline_control_flow_regions +from dace.transformation.pass_pipeline import FixedPointPipeline from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising -def test_dataflow_if_check(): +@pytest.mark.parametrize('lowered_returns', [False, True]) +def test_dataflow_if_check(lowered_returns: bool): @dace.program def dataflow_if_check(A: dace.int32[10], i: dace.int64): @@ -21,12 +23,7 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): sdfg = dataflow_if_check.to_sdfg() # To test raising, we inline the control flow generated by the frontend. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg, lower_returns=lowered_returns) assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -44,7 +41,8 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): assert sdfg(A, 6)[0] == 0 -def test_nested_if_chain(): +@pytest.mark.parametrize('lowered_returns', [False, True]) +def test_nested_if_chain(lowered_returns: bool): @dace.program def nested_if_chain(i: dace.int64): @@ -65,12 +63,7 @@ def nested_if_chain(i: dace.int64): sdfg = nested_if_chain.to_sdfg() # To test raising, we inline the control flow generated by the frontend. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg, lower_returns=lowered_returns) assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -86,7 +79,8 @@ def nested_if_chain(i: dace.int64): assert nested_if_chain(15)[0] == 4 -def test_elif_chain(): +@pytest.mark.parametrize('lowered_returns', [False, True]) +def test_elif_chain(lowered_returns: bool): @dace.program def elif_chain(i: dace.int64): @@ -104,12 +98,7 @@ def elif_chain(i: dace.int64): sdfg = elif_chain.to_sdfg() # To test raising, we inline the control flow generated by the frontend. - inliner = InlineControlFlowRegions() - inliner.no_inline_conditional = False - inliner.no_inline_loops = False - inliner.no_inline_function_call_regions = False - inliner.no_inline_named_regions = False - inliner.apply_pass(sdfg, {}) + inline_control_flow_regions(sdfg, lower_returns=lowered_returns) assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -126,6 +115,9 @@ def elif_chain(i: dace.int64): if __name__ == '__main__': - test_dataflow_if_check() - test_nested_if_chain() - test_elif_chain() + test_dataflow_if_check(False) + test_dataflow_if_check(True) + test_nested_if_chain(False) + test_nested_if_chain(True) + test_elif_chain(False) + test_elif_chain(True) From c9d6b511c3919a27b1e5b79b1deda8b559befd42 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 11 Dec 2024 15:08:07 +0100 Subject: [PATCH 104/108] Renamed experimental_cfg_blocks to explicit_control_flow --- dace/codegen/control_flow.py | 2 +- dace/codegen/targets/framecode.py | 2 +- dace/frontend/fortran/fortran_parser.py | 18 +++---- dace/frontend/python/interface.py | 7 ++- dace/frontend/python/parser.py | 8 ++-- .../analysis/schedule_tree/sdfg_to_tree.py | 2 +- .../analysis/writeset_underapproximation.py | 2 +- dace/sdfg/propagation.py | 2 +- dace/sdfg/sdfg.py | 16 +++---- dace/transformation/__init__.py | 2 +- dace/transformation/dataflow/map_fission.py | 2 +- dace/transformation/dataflow/map_for_loop.py | 2 +- .../transformation/interstate/block_fusion.py | 2 +- .../interstate/fpga_transform_sdfg.py | 2 +- .../interstate/fpga_transform_state.py | 2 +- .../interstate/gpu_transform_sdfg.py | 2 +- .../interstate/loop_detection.py | 2 +- .../transformation/interstate/loop_lifting.py | 4 +- .../transformation/interstate/loop_peeling.py | 4 +- dace/transformation/interstate/loop_to_map.py | 2 +- dace/transformation/interstate/loop_unroll.py | 2 +- .../interstate/move_assignment_outside_if.py | 2 +- .../interstate/move_loop_into_map.py | 2 +- .../interstate/multistate_inline.py | 11 +---- .../transformation/interstate/sdfg_nesting.py | 8 ++-- .../interstate/state_elimination.py | 14 +++--- .../transformation/interstate/state_fusion.py | 2 +- .../state_fusion_with_happens_before.py | 2 +- .../interstate/trivial_loop_elimination.py | 2 +- .../passes/analysis/analysis.py | 24 +++++----- .../passes/array_elimination.py | 2 +- .../passes/consolidate_edges.py | 4 +- .../passes/constant_propagation.py | 2 +- .../passes/dead_dataflow_elimination.py | 2 +- .../passes/dead_state_elimination.py | 2 +- dace/transformation/passes/fusion_inline.py | 10 ++-- dace/transformation/passes/optional_arrays.py | 2 +- .../transformation/passes/pattern_matching.py | 24 +++++----- dace/transformation/passes/prune_symbols.py | 2 +- .../passes/reference_reduction.py | 2 +- dace/transformation/passes/scalar_fission.py | 2 +- .../transformation/passes/scalar_to_symbol.py | 4 +- .../simplification/control_flow_raising.py | 4 +- .../prune_empty_conditional_branches.py | 2 +- dace/transformation/passes/simplify.py | 12 ++--- dace/transformation/passes/symbol_ssa.py | 2 +- dace/transformation/passes/transient_reuse.py | 4 +- dace/transformation/subgraph/composite.py | 2 +- dace/transformation/transformation.py | 8 ++-- tests/fortran/fortran_loops_test.py | 2 +- .../conditional_regions_test.py | 6 +-- tests/python_frontend/loop_regions_test.py | 48 +++++++++---------- tests/sdfg/loop_region_test.py | 12 ++--- tests/sdfg/work_depth_test.py | 2 +- .../interstate/loop_lifting_test.py | 8 ++-- tests/transformations/loop_detection_test.py | 2 +- 56 files changed, 159 insertions(+), 167 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 31988ba700..5928cc71f2 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -1144,7 +1144,7 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState :param sdfg: The SDFG to iterate over. :return: Control-flow block representing the entire SDFG. """ - if sdfg.root_sdfg.using_experimental_blocks: + if sdfg.root_sdfg.using_explicit_control_flow: return structured_control_flow_tree_with_regions(sdfg, dispatch_state) # Avoid import loops diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index cd98c26479..f760715ef9 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -485,7 +485,7 @@ def dispatch_state(state: SDFGState) -> str: states_generated.add(state) # For sanity check return stream.getvalue() - if sdfg.root_sdfg.recheck_using_experimental_blocks(): + if sdfg.root_sdfg.recheck_using_explicit_control_flow(): # Use control flow blocks embedded in the SDFG to generate control flow. cft = cflow.structured_control_flow_tree_with_regions(sdfg, dispatch_state) elif config.Config.get_bool('optimizer', 'detect_control_flow'): diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 1cdecc99a8..6b14f63edd 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -29,7 +29,7 @@ class AST_translator: """ This class is responsible for translating the internal AST into a SDFG. """ - def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_experimental_cfg_blocks: bool = False): + def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_explicit_cf: bool = False): """ :ast: The internal fortran AST to be used for translation :source: The source file name from which the AST was generated @@ -69,7 +69,7 @@ def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_expe ast_internal_classes.Allocate_Stmt_Node: self.allocate2sdfg, ast_internal_classes.Break_Node: self.break2sdfg, } - self.use_experimental_cfg_blocks = use_experimental_cfg_blocks + self.use_explicit_cf = use_explicit_cf def get_dace_type(self, type): """ @@ -271,7 +271,7 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg :param sdfg: The SDFG to which the node should be translated """ - if not self.use_experimental_cfg_blocks: + if not self.use_explicit_cf: declloop = False name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) begin_state = ast_utils.add_simple_state_to_sdfg(self, cfg, "Begin" + name) @@ -1103,7 +1103,7 @@ def create_sdfg_from_string( source_string: str, sdfg_name: str, normalize_offsets: bool = False, - use_experimental_cfg_blocks: bool = False + use_explicit_cf: bool = False ): """ Creates an SDFG from a fortran file in a string @@ -1133,7 +1133,7 @@ def create_sdfg_from_string( program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) - ast2sdfg = AST_translator(own_ast, __file__, use_experimental_cfg_blocks) + ast2sdfg = AST_translator(own_ast, __file__, use_explicit_cf) sdfg = SDFG(sdfg_name) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg @@ -1148,11 +1148,11 @@ def create_sdfg_from_string( sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None sdfg.reset_cfg_list() - sdfg.using_experimental_blocks = use_experimental_cfg_blocks + sdfg.using_explicit_control_flow = use_explicit_cf return sdfg -def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_blocks: bool = False): +def create_sdfg_from_fortran_file(source_string: str, use_explicit_cf: bool = False): """ Creates an SDFG from a fortran file :param source_string: The fortran file name @@ -1180,11 +1180,11 @@ def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_block program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program).visit(program) - ast2sdfg = AST_translator(own_ast, __file__, use_experimental_cfg_blocks) + ast2sdfg = AST_translator(own_ast, __file__, use_explicit_cf) sdfg = SDFG(source_string) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg ast2sdfg.translate(program, sdfg) - sdfg.using_experimental_blocks = use_experimental_cfg_blocks + sdfg.using_explicit_control_flow = use_explicit_cf return sdfg diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 06bef0ba37..aa8caef826 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -44,7 +44,7 @@ def program(f: F, recompile: bool = True, distributed_compilation: bool = False, constant_functions=False, - use_experimental_cfg_blocks=True, + use_explicit_cf=True, **kwargs) -> Callable[..., parser.DaceProgram]: """ Entry point to a data-centric program. For methods and ``classmethod``s, use @@ -69,8 +69,7 @@ def program(f: F, not depend on internal variables are constant. This will hardcode their return values into the resulting program. - :param use_experimental_cfg_blocks: If True, makes use of experimental CFG blocks susch as loop and conditional - regions. + :param use_explicit_cfl: If True, makes use of explicit control flow constructs. :note: If arguments are defined with type hints, the program can be compiled ahead-of-time with ``.compile()``. """ @@ -87,7 +86,7 @@ def program(f: F, regenerate_code=regenerate_code, recompile=recompile, distributed_compilation=distributed_compilation, - use_experimental_cfg_blocks=use_experimental_cfg_blocks) + use_explicit_cf=use_explicit_cf) function = program diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index dc3667145d..0faa2e36ce 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -156,7 +156,7 @@ def __init__(self, recompile: bool = True, distributed_compilation: bool = False, method: bool = False, - use_experimental_cfg_blocks: bool = True): + use_explicit_cf: bool = True): from dace.codegen import compiled_sdfg # Avoid import loops self.f = f @@ -176,7 +176,7 @@ def __init__(self, self.recreate_sdfg = recreate_sdfg self.regenerate_code = regenerate_code self.recompile = recompile - self.use_experimental_cfg_blocks = use_experimental_cfg_blocks + self.use_explicit_cf = use_explicit_cf self.distributed_compilation = distributed_compilation self.global_vars = _get_locals_and_globals(f) @@ -494,10 +494,10 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) - if not self.use_experimental_cfg_blocks: + if not self.use_explicit_cf: for nsdfg in sdfg.all_sdfgs_recursive(): sdutils.inline_control_flow_regions(nsdfg) - sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + sdfg.using_explicit_control_flow = self.use_explicit_cf sdfg.reset_cfg_list() diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 46ad04f70b..e0bc95ad34 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -652,7 +652,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) ############################# # Create initial tree from CFG - if sdfg.using_experimental_blocks: + if sdfg.using_explicit_control_flow: cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '') else: cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index a0f84e93a6..0426cb0942 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -685,7 +685,7 @@ class UnderapproximateWritesDict: Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] = field(default_factory=dict) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class UnderapproximateWrites(ppl.Pass): # Dictionary mapping each edge to a copy of the memlet of that edge with its write set underapproximated. diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index bf1f03197e..2983ec3c63 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -744,7 +744,7 @@ def propagate_states(sdfg: 'SDFG', concretize_dynamic_unbounded: bool = False) - :note: This operates on the SDFG in-place. """ - if sdfg.using_experimental_blocks: + if sdfg.using_explicit_control_flow: # Avoid cyclic imports from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis import StatePropagation diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index c6a7741628..4a141aef12 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -460,8 +460,8 @@ class SDFG(ControlFlowRegion): desc='Mapping between callback name and its original callback ' '(for when the same callback is used with a different signature)') - using_experimental_blocks = Property(dtype=bool, default=False, - desc="Whether the SDFG contains experimental control flow blocks") + using_explicit_control_flow = Property(dtype=bool, default=False, + desc="Whether the SDFG contains explicit control flow constructs") def __init__(self, name: str, @@ -2844,14 +2844,14 @@ def make_array_memlet(self, array: str): """ return dace.Memlet.from_array(array, self.data(array)) - def recheck_using_experimental_blocks(self) -> bool: - found_experimental_block = False + def recheck_using_explicit_control_flow(self) -> bool: + found_explicit_cf_block = False for node, graph in self.root_sdfg.all_nodes_recursive(): if isinstance(graph, ControlFlowRegion) and not isinstance(graph, SDFG): - found_experimental_block = True + found_explicit_cf_block = True break if isinstance(node, ControlFlowBlock) and not isinstance(node, SDFGState): - found_experimental_block = True + found_explicit_cf_block = True break - self.root_sdfg.using_experimental_blocks = found_experimental_block - return found_experimental_block + self.root_sdfg.using_explicit_control_flow = found_explicit_cf_block + return found_explicit_cf_block diff --git a/dace/transformation/__init__.py b/dace/transformation/__init__.py index 0b27542ca6..7f1a1fb064 100644 --- a/dace/transformation/__init__.py +++ b/dace/transformation/__init__.py @@ -1,4 +1,4 @@ from .transformation import (PatternNode, PatternTransformation, SingleStateTransformation, MultiStateTransformation, SubgraphTransformation, ExpandTransformation, - experimental_cfg_block_compatible, single_level_sdfg_only) + explicit_cf_compatible, single_level_sdfg_only) from .pass_pipeline import Pass, Pipeline, FixedPointPipeline diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index d9798fe81a..9f40a36b4d 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -15,7 +15,7 @@ from typing import List, Optional, Tuple -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class MapFission(transformation.SingleStateTransformation): """ Implements the MapFission transformation. Map fission refers to subsuming a map scope into its internal subgraph, diff --git a/dace/transformation/dataflow/map_for_loop.py b/dace/transformation/dataflow/map_for_loop.py index f224e9dbcf..0cd872d97d 100644 --- a/dace/transformation/dataflow/map_for_loop.py +++ b/dace/transformation/dataflow/map_for_loop.py @@ -112,6 +112,6 @@ def replace_param(param): sdfg.reset_cfg_list() # Ensure the SDFG is marked as containing CFG regions - sdfg.root_sdfg.using_experimental_blocks = True + sdfg.root_sdfg.using_explicit_control_flow = True return node, nstate diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py index a5475cf9cf..cf180ad771 100644 --- a/dace/transformation/interstate/block_fusion.py +++ b/dace/transformation/interstate/block_fusion.py @@ -5,7 +5,7 @@ from dace.transformation import transformation -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class BlockFusion(transformation.MultiStateTransformation): """ Implements the block-fusion transformation. diff --git a/dace/transformation/interstate/fpga_transform_sdfg.py b/dace/transformation/interstate/fpga_transform_sdfg.py index 5c2acf1d64..09a6ee2aa8 100644 --- a/dace/transformation/interstate/fpga_transform_sdfg.py +++ b/dace/transformation/interstate/fpga_transform_sdfg.py @@ -9,7 +9,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FPGATransformSDFG(transformation.MultiStateTransformation): """ Implements the FPGATransformSDFG transformation, which takes an entire SDFG and transforms it into an FPGA-capable SDFG. """ diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index 602cca6df1..dc888d8c33 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -31,7 +31,7 @@ def fpga_update(sdfg: SDFG, state: SDFGState, depth: int): fpga_update(node.sdfg, s, depth + 1) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FPGATransformState(transformation.MultiStateTransformation): """ Implements the FPGATransformState transformation. """ diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 5365d2ce35..901b05cb64 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -85,7 +85,7 @@ def _recursive_in_check(node, state, gpu_scalars): @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class GPUTransformSDFG(transformation.MultiStateTransformation): """ Implements the GPUTransformSDFG transformation. diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 8081447132..fbd627eeeb 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -12,7 +12,7 @@ # NOTE: This class extends PatternTransformation directly in order to not show up in the matches -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class DetectLoop(transformation.PatternTransformation): """ Detects a for-loop construct from an SDFG. """ diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 072c2519ed..746910964c 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -8,7 +8,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class LoopLifting(DetectLoop, transformation.MultiStateTransformation): def can_be_applied(self, graph: transformation.ControlFlowRegion, expr_index: int, sdfg: transformation.SDFG, @@ -95,5 +95,5 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): for n in full_body: graph.remove_node(n) - sdfg.root_sdfg.using_experimental_blocks = True + sdfg.root_sdfg.using_explicit_control_flow = True sdfg.reset_cfg_list() diff --git a/dace/transformation/interstate/loop_peeling.py b/dace/transformation/interstate/loop_peeling.py index 710f6f5d97..94174ab309 100644 --- a/dace/transformation/interstate/loop_peeling.py +++ b/dace/transformation/interstate/loop_peeling.py @@ -11,11 +11,11 @@ from dace.symbolic import pystr_to_symbolic from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.passes.analysis import loop_analysis -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class LoopPeeling(LoopUnroll): """ Splits the first `count` iterations of loop into multiple, separate control flow regions (one per iteration). diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 230545eecf..9b1c460372 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -73,7 +73,7 @@ def _sanitize_by_index(indices: Set[int], subset: subsets.Subset) -> subsets.Ran @properties.make_properties -@xf.experimental_cfg_block_compatible +@xf.explicit_cf_compatible class LoopToMap(xf.MultiStateTransformation): """ Convert a control flow loop into a dataflow map. Currently only supports the simple case where there is no overlap diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index a41dc95dc5..a23777c749 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -13,7 +13,7 @@ from dace.transformation.passes.analysis import loop_analysis @make_properties -@xf.experimental_cfg_block_compatible +@xf.explicit_cf_compatible class LoopUnroll(xf.MultiStateTransformation): """ Unrolls a for-loop into multiple individual control flow regions """ diff --git a/dace/transformation/interstate/move_assignment_outside_if.py b/dace/transformation/interstate/move_assignment_outside_if.py index 6522c67eb8..8cfaa591d7 100644 --- a/dace/transformation/interstate/move_assignment_outside_if.py +++ b/dace/transformation/interstate/move_assignment_outside_if.py @@ -15,7 +15,7 @@ from dace.transformation import transformation -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class MoveAssignmentOutsideIf(transformation.MultiStateTransformation): conditional = transformation.PatternNode(ConditionalBlock) diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index fd7c4353dd..de898c8f5c 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -24,7 +24,7 @@ def offset(memlet_subset_ranges, value): return (memlet_subset_ranges[0] + value, memlet_subset_ranges[1] + value, memlet_subset_ranges[2]) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class MoveLoopIntoMap(transformation.MultiStateTransformation): """ Moves a loop around a map into the map diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 45de07a5a6..89f0edcea9 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -18,7 +18,7 @@ @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class InlineMultistateSDFG(transformation.SingleStateTransformation): """ Inlines a multi-state nested SDFG into a top-level SDFG. This only happens @@ -141,14 +141,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): if isinstance(blk, ReturnBlock): has_return = True if has_return: - # Avoid cyclic imports - from dace.transformation.passes.fusion_inline import InlineControlFlowRegions - from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising - - sdutil.inline_control_flow_regions(nsdfg) - # After inlining, try to lift out control flow again, essentially preserving all control flow that can be - # preserved while removing the return blocks. - ControlFlowRaising().apply_pass(nsdfg, {}) + sdutil.inline_control_flow_regions(nsdfg, lower_returns=True) if nsdfg_node.schedule != dtypes.ScheduleType.Default: infer_types.set_default_schedule_and_storage_types(nsdfg, [nsdfg_node.schedule]) diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index af260a651f..31e751bb6a 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -23,7 +23,7 @@ @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class InlineSDFG(transformation.SingleStateTransformation): """ Inlines a single-state nested SDFG into a top-level SDFG. @@ -738,7 +738,7 @@ def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str, str], new @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class InlineTransients(transformation.SingleStateTransformation): """ Inlines all transient arrays that are not used anywhere else into a nested SDFG. @@ -881,7 +881,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class RefineNestedAccess(transformation.SingleStateTransformation): """ Reduces memlet shape when a memlet is connected to a nested SDFG, but not @@ -1106,7 +1106,7 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class NestSDFG(transformation.MultiStateTransformation): """ Implements SDFG Nesting, taking an SDFG as an input and creating a nested SDFG node from it. """ diff --git a/dace/transformation/interstate/state_elimination.py b/dace/transformation/interstate/state_elimination.py index 8755155615..94619576bf 100644 --- a/dace/transformation/interstate/state_elimination.py +++ b/dace/transformation/interstate/state_elimination.py @@ -13,7 +13,7 @@ from dace.transformation import transformation -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class EndStateElimination(transformation.MultiStateTransformation): """ End-state elimination removes a redundant state that has one incoming edge @@ -61,7 +61,7 @@ def apply(self, graph, sdfg): sdfg.remove_symbol(sym) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StartStateElimination(transformation.MultiStateTransformation): """ Start-state elimination removes a redundant state that has one outgoing edge @@ -134,7 +134,7 @@ def _assignments_to_consider(sdfg, edge, is_constant=False): return assignments_to_consider -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateAssignElimination(transformation.MultiStateTransformation): """ State assign elimination removes all assignments into the final state @@ -231,7 +231,7 @@ def _alias_assignments(sdfg: SDFG, edge: InterstateEdge): return assignments_to_consider -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class SymbolAliasPromotion(transformation.MultiStateTransformation): """ SymbolAliasPromotion moves inter-state assignments that create symbolic @@ -336,7 +336,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): in_edge.assignments[k] = v -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class HoistState(transformation.SingleStateTransformation): """ Move a state out of a nested SDFG """ nsdfg = transformation.PatternNode(nodes.NestedSDFG) @@ -492,7 +492,7 @@ def replfunc(m): nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class TrueConditionElimination(transformation.MultiStateTransformation): """ If a state transition condition is always true, removes condition from edge. @@ -528,7 +528,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): edge.data.condition = CodeBlock("1") -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FalseConditionElimination(transformation.MultiStateTransformation): """ If a state transition condition is always false, removes edge. diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index dbdf7642bd..7e3dc6916b 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -32,7 +32,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateFusion(transformation.MultiStateTransformation): """ Implements the state-fusion transformation. diff --git a/dace/transformation/interstate/state_fusion_with_happens_before.py b/dace/transformation/interstate/state_fusion_with_happens_before.py index ae2007e59f..c358a131f6 100644 --- a/dace/transformation/interstate/state_fusion_with_happens_before.py +++ b/dace/transformation/interstate/state_fusion_with_happens_before.py @@ -31,7 +31,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateFusionExtended(transformation.MultiStateTransformation): """ Implements the state-fusion transformation extended to fuse states with RAW and WAW dependencies. An empty memlet is used to represent a dependency between two subgraphs with RAW and WAW dependencies. diff --git a/dace/transformation/interstate/trivial_loop_elimination.py b/dace/transformation/interstate/trivial_loop_elimination.py index 981d3833b6..e948cba7ba 100644 --- a/dace/transformation/interstate/trivial_loop_elimination.py +++ b/dace/transformation/interstate/trivial_loop_elimination.py @@ -9,7 +9,7 @@ from dace.transformation.passes.analysis import loop_analysis -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class TrivialLoopElimination(transformation.MultiStateTransformation): """ Eliminates loops with a single loop iteration. diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index ffd8a6134f..94c24399ee 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -25,7 +25,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateReachability(ppl.Pass): """ Evaluates state reachability (which other states can be executed after each state). @@ -64,7 +64,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ControlFlowBlockReachability(ppl.Pass): """ Evaluates control flow block reachability (which control flow block can be executed after each control flow block) @@ -195,7 +195,7 @@ def reachable_nodes(G): @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class SymbolAccessSets(ppl.ControlFlowRegionPass): """ Evaluates symbol access sets (which symbols are read/written in each control flow block or interstate edge). @@ -225,7 +225,7 @@ def apply(self, region: ControlFlowRegion, _) -> Dict[Union[ControlFlowBlock, Ed @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class AccessSets(ppl.Pass): """ Evaluates memory access sets (which arrays/data descriptors are read/written in each control flow block). @@ -295,7 +295,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str] @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FindAccessStates(ppl.Pass): """ For each data descriptor, creates a set of states in which access nodes of that data are used. @@ -334,7 +334,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FindAccessNodes(ppl.Pass): """ For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given @@ -371,7 +371,7 @@ def apply_pass(self, top_sdfg: SDFG, @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class SymbolWriteScopes(ppl.ControlFlowRegionPass): """ For each symbol, create a dictionary mapping each interstate edge writing to that symbol to the set of interstate @@ -472,7 +472,7 @@ def apply(self, region, pipeline_results) -> SymbolScopeDict: @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ScalarWriteShadowScopes(ppl.Pass): """ For each scalar or array of size 1, create a dictionary mapping writes to that data container to the set of reads @@ -651,7 +651,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class AccessRanges(ppl.Pass): """ For each data descriptor, finds all memlets used to access it (read/write ranges). @@ -689,7 +689,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FindReferenceSources(ppl.Pass): """ For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used @@ -778,7 +778,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class DeriveSDFGConstraints(ppl.Pass): CATEGORY: str = 'Analysis' @@ -813,7 +813,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Tuple[Dict[str, Set[str]], Dict[str, Set[ return {}, invariants, {} -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StatePropagation(ppl.ControlFlowRegionPass): """ Analyze a control flow region to determine the number of times each block inside of it is executed in the form of a diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index 6681ed6da0..fd472336e0 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -13,7 +13,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ArrayElimination(ppl.Pass): """ Merges and removes arrays and their corresponding accesses. This includes redundant array copies, unnecessary views, diff --git a/dace/transformation/passes/consolidate_edges.py b/dace/transformation/passes/consolidate_edges.py index 5b1aae2621..94cd29b6ae 100644 --- a/dace/transformation/passes/consolidate_edges.py +++ b/dace/transformation/passes/consolidate_edges.py @@ -5,11 +5,11 @@ from dace import SDFG, properties from typing import Optional -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class ConsolidateEdges(ppl.Pass): """ Removes extraneous edges with memlets that refer to the same data containers within the same scope. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 92183451a2..24c35edcc9 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -24,7 +24,7 @@ class _UnknownValue: @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ConstantPropagation(ppl.Pass): """ Propagates constants and symbols that were assigned to one value forward through the SDFG, reducing diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 9bc8e27dd2..908150d5e2 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -20,7 +20,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class DeadDataflowElimination(ppl.ControlFlowRegionPass): """ Removes unused computations from SDFG states. diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index e9c622d128..cc7e262e4d 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -13,7 +13,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class DeadStateElimination(ppl.Pass): """ Removes all unreachable states (e.g., due to a branch that will never be taken) from an SDFG. diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 13583998ac..a873bf0888 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -11,12 +11,12 @@ from dace.sdfg.state import ConditionalBlock, FunctionCallRegion, LoopRegion, NamedRegion from dace.sdfg.utils import fuse_states, inline_control_flow_regions, inline_sdfgs from dace.transformation import pass_pipeline as ppl -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @dataclass(unsafe_hash=True) @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class FuseStates(ppl.Pass): """ Fuses all possible states of an SDFG (and all sub-SDFGs). @@ -53,7 +53,7 @@ def report(self, pass_retval: int) -> str: @dataclass(unsafe_hash=True) @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class InlineSDFGs(ppl.Pass): """ Inlines all possible nested SDFGs (and sub-SDFGs). @@ -91,7 +91,7 @@ def report(self, pass_retval: int) -> str: @dataclass(unsafe_hash=True) @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class InlineControlFlowRegions(ppl.Pass): """ Inlines all control flow regions. @@ -168,7 +168,7 @@ def set_opts(self, opts): @dataclass(unsafe_hash=True) @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class FixNestedSDFGReferences(ppl.Pass): """ Fixes nested SDFG references to parent state/SDFG/node diff --git a/dace/transformation/passes/optional_arrays.py b/dace/transformation/passes/optional_arrays.py index 366231d1f1..6f96f0f53f 100644 --- a/dace/transformation/passes/optional_arrays.py +++ b/dace/transformation/passes/optional_arrays.py @@ -10,7 +10,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class OptionalArrayInference(ppl.Pass): """ Infers the ``optional`` property of arrays, i.e., if they can be given None, throughout the SDFG and all nested diff --git a/dace/transformation/passes/pattern_matching.py b/dace/transformation/passes/pattern_matching.py index faa011f7d9..2149f754a8 100644 --- a/dace/transformation/passes/pattern_matching.py +++ b/dace/transformation/passes/pattern_matching.py @@ -98,13 +98,13 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[str, # For every transformation in the list, find first match and apply for xform in self.transformations: - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(xform, '__experimental_cfg_block_compatible__') or - xform.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(xform, '__explicit_cf_compatible__') or + xform.__explicit_cf_compatible__ == False): warnings.warn('Pattern matching is skipping transformation ' + xform.__class__.__name__ + ' due to incompatibility with experimental control flow blocks. If the ' + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + - 'not have `SDFG.using_experimental_blocks` set to True. If ' + + 'not have `SDFG.using_explicit_control_flow` set to True. If ' + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + @@ -218,13 +218,13 @@ def _apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any], apply_once: while applied_anything: applied_anything = False for xform in xforms: - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(xform, '__experimental_cfg_block_compatible__') or - xform.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(xform, '__explicit_cf_compatible__') or + xform.__explicit_cf_compatible__ == False): warnings.warn('Pattern matching is skipping transformation ' + xform.__class__.__name__ + ' due to incompatibility with experimental control flow blocks. If the ' + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + - 'not have `SDFG.using_experimental_blocks` set to True. If ' + + 'not have `SDFG.using_explicit_control_flow` set to True. If ' + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + @@ -410,13 +410,13 @@ def _try_to_match_transformation(graph: Union[ControlFlowRegion, SDFGState], col for oname, oval in opts.items(): setattr(match, oname, oval) - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(match, '__experimental_cfg_block_compatible__') or - match.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(match, '__explicit_cf_compatible__') or + match.__explicit_cf_compatible__ == False): warnings.warn('Pattern matching is skipping transformation ' + match.__class__.__name__ + ' due to incompatibility with experimental control flow blocks. If the ' + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + - 'not have `SDFG.using_experimental_blocks` set to True. If ' + + 'not have `SDFG.using_explicit_control_flow` set to True. If ' + match.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index a8385b493f..a01d903a1d 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -12,7 +12,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class RemoveUnusedSymbols(ppl.Pass): """ Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``) and interstate edges. diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index a04cd89e77..6418a6025b 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -11,7 +11,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ReferenceToView(ppl.Pass): """ Replaces Reference data descriptors that are only set to one source with views. diff --git a/dace/transformation/passes/scalar_fission.py b/dace/transformation/passes/scalar_fission.py index 0b234f2961..8d88f2752b 100644 --- a/dace/transformation/passes/scalar_fission.py +++ b/dace/transformation/passes/scalar_fission.py @@ -8,7 +8,7 @@ from dace.transformation.passes import analysis as ap -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ScalarFission(ppl.Pass): """ Fission transient scalars or arrays of size 1 that are dominated by a write into separate data containers. diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 2fde6153ec..64e94a1e8a 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -24,7 +24,7 @@ from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.transformation import helpers as xfh from dace.transformation import pass_pipeline as passes -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible class AttributedCallDetector(ast.NodeVisitor): @@ -612,7 +612,7 @@ def translate_cpp_tasklet_to_python(code: str): @dataclass(unsafe_hash=True) @props.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class ScalarToSymbolPromotion(passes.Pass): CATEGORY: str = 'Simplification' diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 6ef900f28f..81b9c6b0eb 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -18,7 +18,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ControlFlowRaising(ppl.Pass): """ Raises all detectable control flow that can be expressed with native SDFG structures, such as loops and branching. @@ -175,7 +175,7 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) lifted = n_cond_regions_post - n_cond_regions_pre if lifted: - sdfg.root_sdfg.using_experimental_blocks = True + sdfg.root_sdfg.using_explicit_control_flow = True return lifted def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int, int]]: diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py index 1f44008351..d7bd397830 100644 --- a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -9,7 +9,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass): """ Prunes empty (or no-op) conditional branches from conditional blocks. diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 97eb383764..bfd22ebaf3 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -48,7 +48,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class SimplifyPass(ppl.FixedPointPipeline): """ A pipeline that simplifies an SDFG by applying a series of simplification passes. @@ -106,15 +106,15 @@ def apply_subpass(self, sdfg: SDFG, p: ppl.Pass, state: Dict[str, Any]): """ Apply a pass from the pipeline. This method is meant to be overridden by subclasses. """ - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(p, '__experimental_cfg_block_compatible__') or - p.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(p, '__explicit_cf_compatible__') or + p.__explicit_cf_compatible__ == False): warnings.warn(p.__class__.__name__ + ' is not being applied due to incompatibility with ' + 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + - 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + + 'ensure the top level SDFG does not have `SDFG.using_explicit_control_flow` set to ' + 'True. If ' + p.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') return None diff --git a/dace/transformation/passes/symbol_ssa.py b/dace/transformation/passes/symbol_ssa.py index 29dec5b861..da0d1cdbb1 100644 --- a/dace/transformation/passes/symbol_ssa.py +++ b/dace/transformation/passes/symbol_ssa.py @@ -8,7 +8,7 @@ from dace.transformation.passes import analysis as ap -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StrictSymbolSSA(ppl.ControlFlowRegionPass): """ Perform an SSA transformation on all symbols in the SDFG in a strict manner, i.e., without introducing phi nodes. diff --git a/dace/transformation/passes/transient_reuse.py b/dace/transformation/passes/transient_reuse.py index 805ddadff4..99e41d724f 100644 --- a/dace/transformation/passes/transient_reuse.py +++ b/dace/transformation/passes/transient_reuse.py @@ -6,11 +6,11 @@ from dace import SDFG, properties from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class TransientReuse(ppl.Pass): """ Reduces memory consumption by reusing allocated transient array memory. Only modifies arrays that can safely be diff --git a/dace/transformation/subgraph/composite.py b/dace/transformation/subgraph/composite.py index a7904e5108..fb09fcb51e 100644 --- a/dace/transformation/subgraph/composite.py +++ b/dace/transformation/subgraph/composite.py @@ -19,7 +19,7 @@ @make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class CompositeFusion(transformation.SubgraphTransformation): """ MultiExpansion + SubgraphFusion in one Transformation Additional StencilTiling is also possible as a canonicalizing diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 9b109bfcfb..8c11c5d200 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -34,8 +34,8 @@ PassT = TypeVar('PassT', bound=ppl.Pass) -def experimental_cfg_block_compatible(cls: PassT) -> PassT: - cls.__experimental_cfg_block_compatible__ = True +def explicit_cf_compatible(cls: PassT) -> PassT: + cls.__explicit_cf_compatible__ = True return cls @@ -506,7 +506,7 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Patt @make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class SingleStateTransformation(PatternTransformation, abc.ABC): """ Base class for pattern-matching transformations that find matches within a single SDFG state. @@ -1061,7 +1061,7 @@ def blocksafe_wrapper(tgt, *args, **kwargs): sdfg = get_sdfg_arg(tgt, *args) if sdfg and isinstance(sdfg, SDFG): root_sdfg: SDFG = sdfg.cfg_list[0] - if not root_sdfg.using_experimental_blocks: + if not root_sdfg.using_explicit_control_flow: return vanilla_method(tgt, *args, **kwargs) else: warnings.warn('Skipping ' + function_name + ' from ' + cls.__name__ + diff --git a/tests/fortran/fortran_loops_test.py b/tests/fortran/fortran_loops_test.py index 4d4c259f07..b18a5e36e8 100644 --- a/tests/fortran/fortran_loops_test.py +++ b/tests/fortran/fortran_loops_test.py @@ -29,7 +29,7 @@ def test_fortran_frontend_loop_region_basic_loop(): ENDDO end SUBROUTINE loop_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, use_experimental_cfg_blocks=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, use_explicit_cf=True) a_test = np.full([10, 10], 2, order="F", dtype=np.float64) b_test = np.full([10, 10], 3, order="F", dtype=np.float64) diff --git a/tests/python_frontend/conditional_regions_test.py b/tests/python_frontend/conditional_regions_test.py index 07e214653c..6a917a13c3 100644 --- a/tests/python_frontend/conditional_regions_test.py +++ b/tests/python_frontend/conditional_regions_test.py @@ -15,7 +15,7 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): return 10 return 100 - dataflow_if_check.use_experimental_cfg_blocks = True + dataflow_if_check.use_explicit_cf = True sdfg = dataflow_if_check.to_sdfg() assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -47,7 +47,7 @@ def nested_if_chain(i: dace.int64): else: return 4 - nested_if_chain.use_experimental_cfg_blocks = True + nested_if_chain.use_explicit_cf = True sdfg = nested_if_chain.to_sdfg() assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -74,7 +74,7 @@ def elif_chain(i: dace.int64): else: return 4 - elif_chain.use_experimental_cfg_blocks = True + elif_chain.use_explicit_cf = True sdfg = elif_chain.to_sdfg() assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) diff --git a/tests/python_frontend/loop_regions_test.py b/tests/python_frontend/loop_regions_test.py index cb7fa30fd4..1047f770da 100644 --- a/tests/python_frontend/loop_regions_test.py +++ b/tests/python_frontend/loop_regions_test.py @@ -15,7 +15,7 @@ def for_loop(): def test_for_loop(): - for_loop.use_experimental_cfg_blocks = True + for_loop.use_explicit_cf = True sdfg = for_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -39,7 +39,7 @@ def for_loop_with_break_continue(): def test_for_loop_with_break_continue(): - for_loop_with_break_continue.use_experimental_cfg_blocks = True + for_loop_with_break_continue.use_explicit_cf = True sdfg = for_loop_with_break_continue.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -68,7 +68,7 @@ def nested_for_loop(): def test_nested_for_loop(): - nested_for_loop.use_experimental_cfg_blocks = True + nested_for_loop.use_explicit_cf = True sdfg = nested_for_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -92,7 +92,7 @@ def while_loop(): def test_while_loop(): - while_loop.use_experimental_cfg_blocks = True + while_loop.use_explicit_cf = True sdfg = while_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -118,7 +118,7 @@ def while_loop_with_break_continue(): def test_while_loop_with_break_continue(): - while_loop_with_break_continue.use_experimental_cfg_blocks = True + while_loop_with_break_continue.use_explicit_cf = True sdfg = while_loop_with_break_continue.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -151,7 +151,7 @@ def nested_while_loop(): def test_nested_while_loop(): - nested_while_loop.use_experimental_cfg_blocks = True + nested_while_loop.use_explicit_cf = True sdfg = nested_while_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -184,7 +184,7 @@ def nested_for_while_loop(): def test_nested_for_while_loop(): - nested_for_while_loop.use_experimental_cfg_blocks = True + nested_for_while_loop.use_explicit_cf = True sdfg = nested_for_while_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -217,7 +217,7 @@ def nested_while_for_loop(): def test_nested_while_for_loop(): - nested_while_for_loop.use_experimental_cfg_blocks = True + nested_while_for_loop.use_explicit_cf = True sdfg = nested_while_for_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -244,7 +244,7 @@ def map_with_break_continue(): def test_map_with_break_continue(): try: - map_with_break_continue.use_experimental_cfg_blocks = True + map_with_break_continue.use_explicit_cf = True map_with_break_continue() except Exception as e: if isinstance(e, DaceSyntaxError): @@ -266,7 +266,7 @@ def test_nested_map_for_loop(): for i in range(10): for j in range(10): ref[i, j] = i * 10 + j - nested_map_for_loop.use_experimental_cfg_blocks = True + nested_map_for_loop.use_explicit_cf = True val = nested_map_for_loop() assert (np.array_equal(val, ref)) @@ -287,7 +287,7 @@ def test_nested_map_for_for_loop(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_map_for_for_loop.use_experimental_cfg_blocks = True + nested_map_for_for_loop.use_explicit_cf = True val = nested_map_for_for_loop() assert (np.array_equal(val, ref)) @@ -308,7 +308,7 @@ def test_nested_for_map_for_loop(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_for_map_for_loop.use_experimental_cfg_blocks = True + nested_for_map_for_loop.use_explicit_cf = True val = nested_for_map_for_loop() assert (np.array_equal(val, ref)) @@ -332,7 +332,7 @@ def test_nested_map_for_loop_with_tasklet(): for i in range(10): for j in range(10): ref[i, j] = i * 10 + j - nested_map_for_loop_with_tasklet.use_experimental_cfg_blocks = True + nested_map_for_loop_with_tasklet.use_explicit_cf = True val = nested_map_for_loop_with_tasklet() assert (np.array_equal(val, ref)) @@ -358,7 +358,7 @@ def test_nested_map_for_for_loop_with_tasklet(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_map_for_for_loop_with_tasklet.use_experimental_cfg_blocks = True + nested_map_for_for_loop_with_tasklet.use_explicit_cf = True val = nested_map_for_for_loop_with_tasklet() assert (np.array_equal(val, ref)) @@ -384,7 +384,7 @@ def test_nested_for_map_for_loop_with_tasklet(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_for_map_for_loop_with_tasklet.use_experimental_cfg_blocks = True + nested_for_map_for_loop_with_tasklet.use_explicit_cf = True val = nested_for_map_for_loop_with_tasklet() assert (np.array_equal(val, ref)) @@ -404,7 +404,7 @@ def test_nested_map_for_loop_2(): for i in range(10): for j in range(10): ref[i, j] = 2 + i * 10 + j - nested_map_for_loop_2.use_experimental_cfg_blocks = True + nested_map_for_loop_2.use_explicit_cf = True val = nested_map_for_loop_2(B) assert (np.array_equal(val, ref)) @@ -430,7 +430,7 @@ def test_nested_map_for_loop_with_tasklet_2(): for i in range(10): for j in range(10): ref[i, j] = 2 + i * 10 + j - nested_map_for_loop_with_tasklet_2.use_experimental_cfg_blocks = True + nested_map_for_loop_with_tasklet_2.use_explicit_cf = True val = nested_map_for_loop_with_tasklet_2(B) assert (np.array_equal(val, ref)) @@ -449,7 +449,7 @@ def test_nested_map_with_symbol(): for i in range(10): for j in range(i, 10): ref[i, j] = i * 10 + j - nested_map_with_symbol.use_experimental_cfg_blocks = True + nested_map_with_symbol.use_explicit_cf = True val = nested_map_with_symbol() assert (np.array_equal(val, ref)) @@ -477,7 +477,7 @@ def for_else(A: dace.float64[20]): for_else.f(expected_1) for_else.f(expected_2) - for_else.use_experimental_cfg_blocks = True + for_else.use_explicit_cf = True for_else(A) assert np.allclose(A, expected_1) @@ -500,7 +500,7 @@ def while_else(A: dace.float64[2]): A[1] = 1.0 A[1] = 1.0 - while_else.use_experimental_cfg_blocks = True + while_else.use_explicit_cf = True A = np.array([0.0, 0.0]) expected = np.array([5.0, 1.0]) @@ -523,7 +523,7 @@ def branch_in_for(cond: dace.int32): def test_branch_in_for(): - branch_in_for.use_experimental_cfg_blocks = True + branch_in_for.use_explicit_cf = True sdfg = branch_in_for.to_sdfg(simplify=False) assert len(sdfg.source_nodes()) == 1 @@ -540,7 +540,7 @@ def branch_in_while(cond: dace.int32): def test_branch_in_while(): - branch_in_while.use_experimental_cfg_blocks = True + branch_in_while.use_explicit_cf = True sdfg = branch_in_while.to_sdfg(simplify=False) assert len(sdfg.source_nodes()) == 1 @@ -553,7 +553,7 @@ def for_with_return(A: dace.int32[10]): return 1 return 0 - for_with_return.use_experimental_cfg_blocks = True + for_with_return.use_explicit_cf = True sdfg = for_with_return.to_sdfg() A = np.full((10,), 1).astype(np.int32) @@ -578,7 +578,7 @@ def for_while_with_return(A: dace.int32[10, 10]): j += 1 return 0 - for_while_with_return.use_experimental_cfg_blocks = True + for_while_with_return.use_explicit_cf = True sdfg = for_while_with_return.to_sdfg() A = np.full((10,10), 1).astype(np.int32) diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py index dedafb67ba..86b84851e2 100644 --- a/tests/sdfg/loop_region_test.py +++ b/tests/sdfg/loop_region_test.py @@ -8,7 +8,7 @@ def _make_regular_for_loop() -> SDFG: sdfg = dace.SDFG('regular_for') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0', update_expr='i = i + 1', inverted=False) @@ -27,7 +27,7 @@ def _make_regular_for_loop() -> SDFG: def _make_regular_while_loop() -> SDFG: sdfg = dace.SDFG('regular_while') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10') sdfg.add_array('A', [10], dace.float32) @@ -47,7 +47,7 @@ def _make_regular_while_loop() -> SDFG: def _make_do_while_loop() -> SDFG: sdfg = dace.SDFG('do_while') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', inverted=True) @@ -67,7 +67,7 @@ def _make_do_while_loop() -> SDFG: def _make_do_for_loop() -> SDFG: sdfg = dace.SDFG('do_for') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [10], dace.float32) state0 = sdfg.add_state('state0', is_start_block=True) @@ -88,7 +88,7 @@ def _make_do_for_loop() -> SDFG: def _make_do_for_inverted_cond_loop() -> SDFG: sdfg = dace.SDFG('do_for_inverted_cond') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [10], dace.float32) state0 = sdfg.add_state('state0', is_start_block=True) @@ -109,7 +109,7 @@ def _make_do_for_inverted_cond_loop() -> SDFG: def _make_triple_nested_for_loop() -> SDFG: sdfg = dace.SDFG('gemm') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) sdfg.add_symbol('j', dace.int32) sdfg.add_symbol('k', dace.int32) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 7c86f454c6..9465900657 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -222,7 +222,7 @@ def test_work_depth(test_name): # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. inline_control_flow_regions(sdfg) for sd in sdfg.all_sdfgs_recursive(): - sd.using_experimental_blocks = False + sd.using_explicit_control_flow = False analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index 20f244621c..676512f5f6 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -45,7 +45,7 @@ def test_lift_regular_for_loop(): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) @@ -101,7 +101,7 @@ def test_lift_loop_llvm_canonical(increment_before_condition): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) @@ -158,7 +158,7 @@ def test_lift_loop_llvm_canonical_while(): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) @@ -201,7 +201,7 @@ def test_do_while(): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index b7c1056162..27ab7a3660 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -19,7 +19,7 @@ def tester(a: dace.float64[20]): for i in range(1, 20): a[i] = a[i - 1] + 1 - tester.use_experimental_cfg_blocks = False + tester.use_explicit_cf = False sdfg = tester.to_sdfg(simplify=False) xform = CountLoops() assert sdfg.apply_transformations(xform) == 1 From e2a6466f883b8900dd9f4bb43ea052967a7e5f13 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 11 Dec 2024 15:34:57 +0100 Subject: [PATCH 105/108] Added more extensible meta access replacement function --- dace/sdfg/replace.py | 13 ++-------- dace/sdfg/state.py | 23 +++++++++++++++++ .../interstate/gpu_transform_sdfg.py | 25 ++----------------- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index b49f13cee6..cab313fc9b 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -11,7 +11,7 @@ from dace import dtypes, properties, symbolic from dace.codegen import cppunparse from dace.frontend.python.astutils import ASTFindReplace -from dace.sdfg.state import ConditionalBlock, LoopRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion if TYPE_CHECKING: from dace.sdfg.state import StateSubgraphView @@ -203,13 +203,4 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]): edge.data.data = repl[edge.data.data] # Replace in loop or branch conditions: - if isinstance(cf, LoopRegion): - replace_in_codeblock(cf.loop_condition, repl) - if cf.update_statement: - replace_in_codeblock(cf.update_statement, repl) - if cf.init_statement: - replace_in_codeblock(cf.init_statement, repl) - elif isinstance(cf, ConditionalBlock): - for c, _ in cf.branches: - if c is not None: - replace_in_codeblock(c, repl) + cf.replace_meta_accesses(repl) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 5e5d07b288..fbc157f74d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -13,6 +13,7 @@ import dace from dace.frontend.python import astutils +from dace.sdfg.replace import replace_in_codeblock import dace.serialize from dace import data as dt from dace import dtypes @@ -2612,6 +2613,16 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: """ return [] + def replace_meta_accesses(self, replacements: dict) -> None: + """ + Replace accesses to specific data containers in reads or writes performed by the control flow region itself in + meta accesses, such as in condition checks for conditional blocks or in loop conditions for loops, etc. + + :param replacements: A dictionary mapping the current data container names to the names of data containers with + which accesses to them should be replaced. + """ + pass + @property def root_sdfg(self) -> 'SDFG': from dace.sdfg.sdfg import SDFG # Avoid import loop @@ -3304,6 +3315,13 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays)) return read_memlets + def replace_meta_accesses(self, replacements): + replace_in_codeblock(self.loop_condition, replacements) + if self.init_statement: + replace_in_codeblock(self.init_statement, replacements) + if self.update_statement: + replace_in_codeblock(self.update_statement, replacements) + def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, @@ -3418,6 +3436,11 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optio def sub_regions(self): return [b for _, b in self.branches] + def replace_meta_accesses(self, replacements): + for c, _ in self.branches: + if c is not None: + replace_in_codeblock(c, replacements) + def __str__(self): return self._label diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 901b05cb64..49a2e16227 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -598,35 +598,14 @@ def _create_copy_out(arrays_used: Set[str]) -> Dict[str, str]: e.data.replace(devicename, hostname, False) for block in list(sdfg.all_control_flow_blocks()): - arrays_used = set() - if isinstance(block, ConditionalBlock): - for c, _ in block.branches: - if c is not None: - arrays_used.update(set(c.get_free_symbols()) & cloned_data) - elif isinstance(block, LoopRegion): - arrays_used.update(set(block.loop_condition.get_free_symbols()) & cloned_data) - if block.init_statement: - arrays_used.update(set(block.init_statement.get_free_symbols()) & cloned_data) - if block.update_statement: - arrays_used.update(set(block.update_statement.get_free_symbols()) & cloned_data) - else: - continue + arrays_used = set(block.used_symbols(all_symbols=True, with_contents=False)) & cloned_data # Create a state and copy out used arrays if len(arrays_used) > 0: co_state = block.parent_graph.add_state_before(block, block.label + '_icopyout') mapping = _create_copy_out(arrays_used) for devicename, hostname in mapping.items(): - if isinstance(block, ConditionalBlock): - for c, _ in block.branches: - if c is not None: - replace_in_codeblock(c, {devicename: hostname}) - elif isinstance(block, LoopRegion): - replace_in_codeblock(block.loop_condition, {devicename: hostname}) - if block.init_statement: - replace_in_codeblock(block.init_statement, {devicename: hostname}) - if block.update_statement: - replace_in_codeblock(block.update_statement, {devicename: hostname}) + block.replace_meta_accesses({devicename: hostname}) # Step 9: Simplify if not self.simplify: From 6f82fc6d65d5fa2ce58db43f2fc4d1ac373cd4a9 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 11 Dec 2024 15:49:56 +0100 Subject: [PATCH 106/108] Fixes --- dace/sdfg/replace.py | 1 - dace/transformation/passes/pattern_matching.py | 6 +++--- dace/transformation/passes/scalar_to_symbol.py | 13 +------------ .../subgraph/gpu_persistent_fusion.py | 10 +++++----- doc/frontend/parsing.rst | 8 ++++---- tests/sdfg/work_depth_test.py | 2 +- 6 files changed, 14 insertions(+), 26 deletions(-) diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index cab313fc9b..2f9ead4dcd 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -11,7 +11,6 @@ from dace import dtypes, properties, symbolic from dace.codegen import cppunparse from dace.frontend.python.astutils import ASTFindReplace -from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion if TYPE_CHECKING: from dace.sdfg.state import StateSubgraphView diff --git a/dace/transformation/passes/pattern_matching.py b/dace/transformation/passes/pattern_matching.py index 2149f754a8..7aa16633fd 100644 --- a/dace/transformation/passes/pattern_matching.py +++ b/dace/transformation/passes/pattern_matching.py @@ -107,7 +107,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[str, 'not have `SDFG.using_explicit_control_flow` set to True. If ' + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') continue @@ -227,7 +227,7 @@ def _apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any], apply_once: 'not have `SDFG.using_explicit_control_flow` set to True. If ' + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') continue @@ -419,7 +419,7 @@ def _try_to_match_transformation(graph: Union[ControlFlowRegion, SDFGState], col 'not have `SDFG.using_explicit_control_flow` set to True. If ' + match.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') return None diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 64e94a1e8a..a8e02dacd3 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -235,18 +235,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer for edge in sdfg.all_interstate_edges(): interstate_symbols |= edge.data.free_symbols for reg in sdfg.all_control_flow_regions(): - if isinstance(reg, LoopRegion): - interstate_symbols |= reg.loop_condition.get_free_symbols() - if reg.loop_variable: - interstate_symbols.add(reg.loop_variable) - if reg.update_statement: - interstate_symbols |= reg.update_statement.get_free_symbols() - if reg.init_statement: - interstate_symbols |= reg.init_statement.get_free_symbols() - elif isinstance(reg, ConditionalBlock): - for c, _ in reg.branches: - if c is not None: - interstate_symbols |= c.get_free_symbols() + interstate_symbols |= reg.used_symbols(all_symbols=True, with_contents=False) for candidate in (candidates - interstate_symbols): if integers_only and sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES: candidates.remove(candidate) diff --git a/dace/transformation/subgraph/gpu_persistent_fusion.py b/dace/transformation/subgraph/gpu_persistent_fusion.py index b7c201a3d7..a9a75d6bc7 100644 --- a/dace/transformation/subgraph/gpu_persistent_fusion.py +++ b/dace/transformation/subgraph/gpu_persistent_fusion.py @@ -194,11 +194,11 @@ def apply(self, sdfg: SDFG): if k in sdfg.symbols and k not in kernel_sdfg.symbols: kernel_sdfg.add_symbol(k, sdfg.symbols[k]) for blk in subgraph_blocks: - if isinstance(blk, LoopRegion): - if blk.loop_variable and blk.init_statement: - new_symbols.add(blk.loop_variable) - if blk.loop_variable in sdfg.symbols and blk.loop_variable not in kernel_sdfg.symbols: - kernel_sdfg.add_symbol(blk.loop_variable, sdfg.symbols[blk.loop_variable]) + if isinstance(blk, AbstractControlFlowRegion): + for k, v in blk.new_symbols(sdfg.symbols).items(): + new_symbols.add(k) + if k not in kernel_sdfg.symbols: + kernel_sdfg.add_symbol(k, v) # Setting entry node in nested SDFG if no entry guard was created if entry_guard_state is None: diff --git a/doc/frontend/parsing.rst b/doc/frontend/parsing.rst index 7adc415497..d909cd7deb 100644 --- a/doc/frontend/parsing.rst +++ b/doc/frontend/parsing.rst @@ -169,7 +169,7 @@ Example: :alt: Generated SDFG for-loop for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, this will utilize +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, this will utilize :class:`~dace.sdfg.state.LoopRegion`s instead of the explicit state machine depicted above. :func:`~dace.frontend.python.newast.ProgramVisitor.visit_While` @@ -191,7 +191,7 @@ Parses `while `_ statement :alt: Generated SDFG while-loop for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, this will utilize +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, this will utilize :class:`~dace.sdfg.state.LoopRegion`s instead of the explicit state machine depicted above. :func:`~dace.frontend.python.newast.ProgramVisitor.visit_Break` @@ -214,7 +214,7 @@ behaves as an if-else statement. This is also evident from the generated dataflo :alt: Generated SDFG for-loop with a break statement for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, loops are +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, loops are represented with :class:`~dace.sdfg.state.LoopRegion`s, and a break is represented with a special :class:`~dace.sdfg.state.LoopRegion.BreakState`. @@ -238,7 +238,7 @@ of `continue` makes the ``A[i] = i`` statement unreachable. This is also evident :alt: Generated SDFG for-loop with a continue statement for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, loops are +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, loops are represented with :class:`~dace.sdfg.state.LoopRegion`s, and a continue is represented with a special :class:`~dace.sdfg.state.LoopRegion.ContinueState`. diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 9465900657..5ecda1cb88 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -275,7 +275,7 @@ def test_avg_par(test_name: str): # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. inline_control_flow_regions(sdfg) for sd in sdfg.all_sdfgs_recursive(): - sd.using_experimental_blocks = False + sd.using_explicit_control_flow = False analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] From 34a32471005a786dd7d94849129b337d841941c0 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 11 Dec 2024 16:20:44 +0100 Subject: [PATCH 107/108] Add more API methods --- dace/sdfg/state.py | 22 +++++ dace/sdfg/utils.py | 83 +++++++++---------- .../transformation/passes/scalar_to_symbol.py | 18 +--- 3 files changed, 66 insertions(+), 57 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index fbc157f74d..9ff5d6ad3f 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2606,6 +2606,13 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, self._cached_start_block: Optional[ControlFlowBlock] = None self._cfg_list: List['ControlFlowRegion'] = [self] + def get_meta_codeblocks(self) -> List[CodeBlock]: + """ + Get a list of codeblocks used by the control flow region. + This may include things such as loop control statements or conditions for branching etc. + """ + return [] + def get_meta_read_memlets(self) -> List[mm.Memlet]: """ Get read memlets used by the control flow region itself, such as in condition checks for conditional blocks, or @@ -3305,6 +3312,14 @@ def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: return True, (init_state, guard_state, end_state) + def get_meta_codeblocks(self): + codes = [self.loop_condition] + if self.init_statement: + codes.append(self.init_statement) + if self.update_statement: + codes.append(self.update_statement) + return codes + def get_meta_read_memlets(self) -> List[mm.Memlet]: # Avoid cyclic imports. from dace.sdfg.sdfg import memlets_in_ast @@ -3459,6 +3474,13 @@ def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): def remove_branch(self, branch: ControlFlowRegion): self._branches = [(c, b) for c, b in self._branches if b is not branch] + def get_meta_codeblocks(self): + codes = [] + for c, _ in self.branches: + if c is not None: + codes.append(c) + return codes + def get_meta_read_memlets(self) -> List[mm.Memlet]: # Avoid cyclic imports. from dace.sdfg.sdfg import memlets_in_ast diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index dddbfb7652..26b6629a81 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1557,51 +1557,48 @@ def _traverse(scope: Node, symbols: Dict[str, dtypes.typeclass]): def _tswds_cf_region( sdfg: SDFG, - region: AbstractControlFlowRegion, + cfg: AbstractControlFlowRegion, symbols: Dict[str, dtypes.typeclass], recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: - if isinstance(region, ConditionalBlock): - for _, b in region.branches: - yield from _tswds_cf_region(sdfg, b, symbols, recursive) - return - - # Add the own loop variable to the defined symbols, if present. - loop_syms = region.new_symbols(symbols) - symbols.update({k: v for k, v in loop_syms.items() if v is not None}) - - # Add symbols from inter-state edges along the state machine - start_region = region.start_block - visited = set() - visited_edges = set() - for edge in region.dfs_edges(start_region): - # Source -> inter-state definition -> Destination - visited_edges.add(edge) - # Source - if edge.src not in visited: - visited.add(edge.src) - if isinstance(edge.src, SDFGState): - yield from _tswds_state(sdfg, edge.src, {}, recursive) - elif isinstance(edge.src, AbstractControlFlowRegion): - yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) - - # Add edge symbols into defined symbols - issyms = edge.data.new_symbols(sdfg, symbols) - symbols.update({k: v for k, v in issyms.items() if v is not None}) - - # Destination - if edge.dst not in visited: - visited.add(edge.dst) - if isinstance(edge.dst, SDFGState): - yield from _tswds_state(sdfg, edge.dst, symbols, recursive) - elif isinstance(edge.dst, AbstractControlFlowRegion): - yield from _tswds_cf_region(sdfg, edge.dst, symbols, recursive) - - # If there is only one state, the DFS will miss it - if start_region not in visited: - if isinstance(start_region, SDFGState): - yield from _tswds_state(sdfg, start_region, symbols, recursive) - elif isinstance(start_region, AbstractControlFlowRegion): - yield from _tswds_cf_region(sdfg, start_region, symbols, recursive) + sub_regions = cfg.sub_regions() or [cfg] + for region in sub_regions: + # Add symbols newly defined by this region, if present. + region_symbols = region.new_symbols(symbols) + symbols.update({k: v for k, v in region_symbols.items() if v is not None}) + + # Add symbols from inter-state edges along the state machine + start_region = region.start_block + visited = set() + visited_edges = set() + for edge in region.dfs_edges(start_region): + # Source -> inter-state definition -> Destination + visited_edges.add(edge) + # Source + if edge.src not in visited: + visited.add(edge.src) + if isinstance(edge.src, SDFGState): + yield from _tswds_state(sdfg, edge.src, {}, recursive) + elif isinstance(edge.src, AbstractControlFlowRegion): + yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) + + # Add edge symbols into defined symbols + issyms = edge.data.new_symbols(sdfg, symbols) + symbols.update({k: v for k, v in issyms.items() if v is not None}) + + # Destination + if edge.dst not in visited: + visited.add(edge.dst) + if isinstance(edge.dst, SDFGState): + yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + elif isinstance(edge.dst, AbstractControlFlowRegion): + yield from _tswds_cf_region(sdfg, edge.dst, symbols, recursive) + + # If there is only one state, the DFS will miss it + if start_region not in visited: + if isinstance(start_region, SDFGState): + yield from _tswds_state(sdfg, start_region, symbols, recursive) + elif isinstance(start_region, AbstractControlFlowRegion): + yield from _tswds_cf_region(sdfg, start_region, symbols, recursive) def traverse_sdfg_with_defined_symbols( diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index a8e02dacd3..43cd45146d 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -739,20 +739,10 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: assignment = cleanup_re[scalar].sub(scalar, assignment.strip()) ise.assignments[aname] = assignment for reg in sdfg.all_control_flow_regions(): - if isinstance(reg, LoopRegion): - codes = [reg.loop_condition] - if reg.init_statement: - codes.append(reg.init_statement) - if reg.update_statement: - codes.append(reg.update_statement) - for cd in codes: - for stmt in cd.code: - promo.visit(stmt) - elif isinstance(reg, ConditionalBlock): - for c, _ in reg.branches: - if c is not None: - for stmt in c.code: - promo.visit(stmt) + meta_codes = reg.get_meta_codeblocks() + for cd in meta_codes: + for stmt in cd.code: + promo.visit(stmt) # Step 7: Indirection remove_symbol_indirection(sdfg) From fad34249f06d67819ed8ad8bff9026ea9f2fe670 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 12 Dec 2024 09:01:21 +0100 Subject: [PATCH 108/108] Address comments --- dace/frontend/python/interface.py | 2 +- dace/sdfg/state.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index aa8caef826..6fb92077b7 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -69,7 +69,7 @@ def program(f: F, not depend on internal variables are constant. This will hardcode their return values into the resulting program. - :param use_explicit_cfl: If True, makes use of explicit control flow constructs. + :param use_explicit_cf: If True, makes use of explicit control flow constructs. :note: If arguments are defined with type hints, the program can be compiled ahead-of-time with ``.compile()``. """ diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 9ff5d6ad3f..30640306cd 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2585,14 +2585,14 @@ class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.Inte ControlFlowBlock, abc.ABC): """ Abstract superclass to represent all kinds of control flow regions in an SDFG. - This is consequently one of the three main classes of control flow graph nodes, which include `ControlFlowBlock`s, - `SDFGState`s, and nested `AbstractControlFlowRegion`s. An `AbstractControlFlowRegion` can further be either a region - that directly contains a control flow graph (`ControlFlowRegion`s and subclasses thereof), or something that acts - like and has the same utilities as a control flow region, including the same API, but is itself not directly a - single graph. An example of this is the `ConditionalBlock`, which acts as a single control flow region to the - outside, but contains multiple actual graphs (one per branch). As such, there are very few but important differences - between the subclasses of `AbstractControlFlowRegion`s, such as how traversals are performed, how many start blocks - there are, etc. + This is consequently one of the three main classes of control flow graph nodes, which include ``ControlFlowBlock``s, + ``SDFGState``s, and nested ``AbstractControlFlowRegion``s. An ``AbstractControlFlowRegion`` can further be either a + region that directly contains a control flow graph (``ControlFlowRegion``s and subclasses thereof), or something + that acts like and has the same utilities as a control flow region, including the same API, but is itself not + directly a single graph. An example of this is the ``ConditionalBlock``, which acts as a single control flow region + to the outside, but contains multiple actual graphs (one per branch). As such, there are very few but important + differences between the subclasses of ``AbstractControlFlowRegion``s, such as how traversals are performed, how many + start blocks there are, etc. """ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, @@ -2753,7 +2753,7 @@ def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: return False, None - def new_symbols(self, symbols: dict) -> Dict[str, dtypes.typeclass]: + def new_symbols(self, symbols: Dict[str, dtypes.typeclass]) -> Dict[str, dtypes.typeclass]: """ Returns a mapping between the symbol defined by this control flow region and its type, if it exists. """ @@ -3099,7 +3099,7 @@ def start_block(self, block_id): @make_properties class ControlFlowRegion(AbstractControlFlowRegion): """ - A `ControlFlowRegion` represents a control flow graph node that itself contains a control flow graph. + A ``ControlFlowRegion`` represents a control flow graph node that itself contains a control flow graph. This can be an arbitrary control flow graph, but may also be a specific type of control flow region with additional semantics, such as a loop or a function call. """