From e82870a5e0485f09eda31e6ba413491838308540 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Mon, 16 Dec 2024 13:59:27 +0100 Subject: [PATCH] Let LiftStructViews lift interstate edge struct accesses to views (#1827) `LiftStructViews` is a necessary pass for ensuring structure member accesses and container array accesses work correctly, by lifting accesses to them into a 'tower of views' that correctly traces their accesses to the root data container. However, this previously did not consider interstate edges. This PR fixes this by making two changes: - Allow `LiftStructViews` to lift struct accesses on interstate edges to the correct tower of views, by inserting a lifting state right before the interstate edge, where views are constructed. The interstate edge is then re-written to access the correct view instead. - Fix `ArrayElimination` to correctly handle these views, which are not directly connected to other nodes, by preventing it from incorrectly merging or removing them. This is done by allowing source/sink view nodes to be merged correctly, i.e., when they are 'the same view', i.e., their view memlets are identical. This shows an example of the constructed tower of views for a complex interstate edge struct access: ![image](https://github.com/user-attachments/assets/bc8173fe-a6a9-4592-97d9-f1b02a6da740) --- dace/sdfg/utils.py | 6 +- .../passes/array_elimination.py | 60 ++++- .../passes/lift_struct_views.py | 245 ++++++++++++++++-- 3 files changed, 276 insertions(+), 35 deletions(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 26b6629a81..8160b1de72 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -867,16 +867,16 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg return None in_edge = in_edges[0] - out_edge = out_edges[0] + out_edge = out_edges[0] if len(out_edges) > 0 else None # If there is one incoming and one outgoing edge, and one leads to a code # node, the one that leads to an access node is the viewed data. inmpath = state.memlet_path(in_edge) - outmpath = state.memlet_path(out_edge) + outmpath = state.memlet_path(out_edge) if out_edge else None src_is_data, dst_is_data = False, False if isinstance(inmpath[0].src, nd.AccessNode): src_is_data = True - if isinstance(outmpath[-1].dst, nd.AccessNode): + if outmpath and isinstance(outmpath[-1].dst, nd.AccessNode): dst_is_data = True if src_is_data and not dst_is_data: diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index fd472336e0..803c81b21e 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -1,10 +1,13 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Set from dace import SDFG, SDFGState, data, properties +from dace.memlet import Memlet from dace.sdfg import nodes from dace.sdfg.analysis import cfg +from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.validation import InvalidSDFGNodeError from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.dataflow import (RedundantArray, RedundantReadSlice, RedundantSecondArray, RedundantWriteSlice, SqueezeViewRemove, UnsqueezeViewRemove, RemoveSliceView) @@ -66,9 +69,6 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S removed_nodes = self.merge_access_nodes(state, access_nodes, lambda n: state.in_degree(n) == 0) removed_nodes |= self.merge_access_nodes(state, access_nodes, lambda n: state.out_degree(n) == 0) - # Update access nodes with merged nodes - access_nodes = {k: [n for n in v if n not in removed_nodes] for k, v in access_nodes.items()} - # Remove redundant views removed_nodes |= self.remove_redundant_views(sdfg, state, access_nodes) @@ -105,25 +105,59 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node Merges access nodes that follow the same conditions together to the first access node. """ removed_nodes: Set[nodes.AccessNode] = set() - for nodeset in access_nodes.values(): + for data_container in access_nodes.keys(): + nodeset = access_nodes[data_container] if len(nodeset) > 1: - # Merge all other access nodes to the first one - first_node = nodeset[0] - if not condition(first_node): + # Merge all other access nodes to the first one that fits the condition, if one exists. + first_node = None + first_node_idx = 0 + for i, node in enumerate(nodeset[:-1]): + if condition(node): + first_node = node + first_node_idx = i + break + if first_node is None: continue - for node in nodeset[1:]: + + for node in nodeset[first_node_idx + 1:]: if not condition(node): continue - # Reconnect edges to first node - for edge in state.all_edges(node): + # Reconnect edges to first node. + # If we are handling views, we do not want to add more than one edge going into a 'views' connector, + # so we only merge nodes if the memlets match exactly (which they should). But in that case without + # copying the edge. + edges: List[MultiConnectorEdge[Memlet]] = state.all_edges(node) + other_edges: List[MultiConnectorEdge[Memlet]] = [] + for edge in edges: if edge.dst is node: - state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data) + if edge.dst_conn == 'views': + other_edges = list(state.in_edges_by_connector(first_node, 'views')) + if len(other_edges) != 1: + raise InvalidSDFGNodeError('Multiple edges connected to views connector', + state.sdfg, state.block_id, state.node_id(first_node)) + other_view_edge = other_edges[0] + if other_view_edge.data != edge.data: + # The memlets do not match, skip the node. + continue + else: + state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data) else: - state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data) + if edge.src_conn == 'views': + other_edges = list(state.out_edges_by_connector(first_node, 'views')) + if len(other_edges) != 1: + raise InvalidSDFGNodeError('Multiple edges connected to views connector', + state.sdfg, state.block_id, state.node_id(first_node)) + other_view_edge = other_edges[0] + if other_view_edge.data != edge.data: + # The memlets do not match, skip the node. + continue + else: + state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data) # Remove merged node and associated edges state.remove_node(node) removed_nodes.add(node) + access_nodes[data_container] = [n for n in nodeset if n not in removed_nodes] return removed_nodes def remove_redundant_views(self, sdfg: SDFG, state: SDFGState, access_nodes: Dict[str, List[nodes.AccessNode]]): diff --git a/dace/transformation/passes/lift_struct_views.py b/dace/transformation/passes/lift_struct_views.py index 619a86d3ed..6744161000 100644 --- a/dace/transformation/passes/lift_struct_views.py +++ b/dace/transformation/passes/lift_struct_views.py @@ -1,12 +1,15 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from collections import defaultdict -from typing import Any, Dict, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from dace import SDFG, Memlet, SDFGState from dace.frontend.python import astutils +from dace.properties import CodeBlock from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.graph import Edge, MultiConnectorEdge +from dace.sdfg.sdfg import InterstateEdge, memlets_in_ast +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.transformation import pass_pipeline as ppl from dace import data as dt from dace import dtypes @@ -187,6 +190,163 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: else: raise NotImplementedError() +class InterstateEdgeRecoder(ast.NodeTransformer): + + sdfg: SDFG + edge: Edge[InterstateEdge] + data_name: str + data: Union[dt.Structure, dt.ContainerArray] + views_constructed: Set[str] + isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] + + def __init__(self, sdfg: SDFG, edge: Edge[InterstateEdge], data_name: str, + data: Union[dt.Structure, dt.ContainerArray], + isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]): + self.sdfg = sdfg + self.edge = edge + self.data_name = data_name + self.data = data + self.views_constructed = set() + self.isedge_lifting_state_dict = isedge_lifting_state_dict + + def _handle_simple_name_access(self, node: ast.Attribute) -> Any: + struct: dt.Structure = self.data + if not node.attr in struct.members: + raise RuntimeError( + f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition' + ) + + # Insert the appropriate view, if it does not exist yet. + view_name = 'v_' + self.data_name + '_' + node.attr + try: + view = self.sdfg.arrays[view_name] + except KeyError: + view = dt.View.view(struct.members[node.attr]) + view_name = self.sdfg.add_datadesc(view_name, view, find_new_name=True) + self.views_constructed.add(view_name) + + # Construct the correct AST replacement node (direct access, i.e., name node). + replacement = ast.Name() + replacement.ctx = ast.Load() + replacement.id = view_name + + # Add access nodes for the view and the original container and connect them appropriately. + lift_state, data_node = self._get_or_create_lifting_state() + view_node = lift_state.add_access(view_name) + lift_state.add_edge(data_node, None, view_node, 'views', + Memlet.from_array(data_node.data + '.' + node.attr, self.data.members[node.attr])) + return self.generic_visit(replacement) + + def _handle_sliced_access(self, node: ast.Attribute, val: ast.Subscript) -> Any: + struct = self.data.stype + if not isinstance(struct, dt.Structure): + raise ValueError('Invalid ContainerArray, can only lift ContainerArrays to Structures') + if not node.attr in struct.members: + raise RuntimeError( + f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition' + ) + + # We first lift the slice into a separate view, and then the attribute access. + slice_view_name = 'v_' + self.data_name + '_slice' + attr_view_name = slice_view_name + '_' + node.attr + try: + slice_view = self.sdfg.arrays[slice_view_name] + except KeyError: + slice_view = dt.View.view(struct) + slice_view_name = self.sdfg.add_datadesc(slice_view_name, slice_view, find_new_name=True) + try: + attr_view = self.sdfg.arrays[attr_view_name] + except KeyError: + member: dt.Data = struct.members[node.attr] + attr_view = dt.View.view(member) + attr_view_name = self.sdfg.add_datadesc(attr_view_name, attr_view, find_new_name=True) + self.views_constructed.add(slice_view_name) + self.views_constructed.add(attr_view_name) + + # Construct the correct AST replacement node (direct access, i.e., name node). + replacement = ast.Name() + replacement.ctx = ast.Load() + replacement.id = attr_view_name + + # Add access nodes for the views to the slice and attribute and connect them appropriately to the original data + # container. + lift_state, data_node = self._get_or_create_lifting_state() + slice_view_node = lift_state.add_access(slice_view_name) + attr_view_node = lift_state.add_access(attr_view_name) + idx = astutils.unparse(val.slice) + if isinstance(val.slice, ast.Tuple): + idx = idx.strip('()') + slice_memlet = Memlet(data_node.data + '[' + idx + ']') + lift_state.add_edge(data_node, None, slice_view_node, 'views', slice_memlet) + attr_memlet = Memlet.from_array(slice_view_name + '.' + node.attr, struct.members[node.attr]) + lift_state.add_edge(slice_view_node, None, attr_view_node, 'views', attr_memlet) + return self.generic_visit(replacement) + + def _get_or_create_lifting_state(self) -> Tuple[SDFGState, nd.AccessNode]: + # Add a state for lifting before the edge, if there isn't one that was created already. + if self.edge.data in self.isedge_lifting_state_dict: + lift_state = self.isedge_lifting_state_dict[self.edge.data] + else: + pre_node: ControlFlowBlock = self.edge.src + lift_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting') + self.isedge_lifting_state_dict[self.edge.data] = lift_state + + # Add a node for the original data container so the view can be connected to it. This may already be a view from + # a previous iteration of lifting, but in that case it is already correctly connected to a root data container. + data_node = None + for dn in lift_state.data_nodes(): + if dn.data == self.data_name: + data_node = dn + break + if data_node is None: + data_node = lift_state.add_access(self.data_name) + + return lift_state, data_node + + def visit_Attribute(self, node: ast.Attribute) -> Any: + if not node.value: + return self.generic_visit(node) + + if isinstance(self.data, dt.Structure): + if isinstance(node.value, ast.Name) and node.value.id == self.data_name: + return self._handle_simple_name_access(node) + elif (isinstance(node.value, ast.Subscript) and isinstance(node.value.slice, ast.Constant) and + node.value.slice.value == 0 and isinstance(node.value.value, ast.Name) and + node.value.value.id == self.data_name): + return self._handle_simple_name_access(node) + return self.generic_visit(node) + else: + # ContainerArray case. + if isinstance(node.value, ast.Name) and node.value.id == self.data_name: + # We are directly accessing a slice of a container array / view. That needs an inserted view to the + # container first. + slice_view_name = 'v_' + self.data_name + '_slice' + try: + slice_view = self.sdfg.arrays[slice_view_name] + except KeyError: + slice_view = dt.View.view(self.data.stype) + slice_view_name = self.sdfg.add_datadesc(slice_view_name, slice_view, find_new_name=True) + self.views_constructed.add(slice_view_name) + + # Add an access node for the slice view and connect it appropriately to the root data container. + lift_state, data_node = self._get_or_create_lifting_state() + slice_view_node = lift_state.add_access(slice_view_name) + lift_state.add_edge(data_node, None, slice_view_node, 'views', + Memlet.from_array(self.data_name, self.sdfg.data(self.data_name))) + elif (isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) and + node.value.value.id == self.data_name): + return self._handle_sliced_access(node, node.value) + return self.generic_visit(node) + + +def _data_containers_in_ast(node: ast.AST, arrnames: Set[str]) -> Set[str]: + result: Set[str] = set() + for subnode in ast.walk(node): + if isinstance(subnode, (ast.Attribute, ast.Subscript)): + data = astutils.rname(subnode.value) + if data in arrnames: + result.add(data) + return result class LiftStructViews(ppl.Pass): """ @@ -200,6 +360,8 @@ class LiftStructViews(ppl.Pass): CATEGORY: str = 'Optimization Preparation' + _isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] = dict() + def modifies(self) -> ppl.Modifies: return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes | ppl.Modifies.Tasklets | ppl.Modifies.Memlets @@ -209,6 +371,40 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {} + def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], result: Dict[str, Set[str]]) -> bool: + lifted_something = False + for k in edge.data.assignments.keys(): + assignment = edge.data.assignments[k] + assignment_str = str(assignment) + assignment_ast = ast.parse(assignment_str) + data_in_edge = _data_containers_in_ast(assignment_ast, cfg.sdfg.arrays.keys()) + for data in data_in_edge: + if '.' in data: + continue + container = cfg.sdfg.arrays[data] + if isinstance(container, (dt.Structure, dt.ContainerArray)): + visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict) + new_code = visitor.visit(assignment_ast) + edge.data.assignments[k] = astutils.unparse(new_code) + assignment_ast = new_code + result[data].update(visitor.views_constructed) + lifted_something = True + if not edge.data.is_unconditional(): + condition_ast = edge.data.condition.code[0] + data_in_edge = _data_containers_in_ast(condition_ast, cfg.sdfg.arrays.keys()) + for data in data_in_edge: + if '.' in data: + continue + container = cfg.sdfg.arrays[data] + if isinstance(container, (dt.Structure, dt.ContainerArray)): + visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict) + new_code = visitor.visit(condition_ast) + edge.data.condition = CodeBlock([new_code]) + condition_ast = new_code + result[data].update(visitor.views_constructed) + lifted_something = True + return lifted_something + def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.Tasklet, edge: MultiConnectorEdge[Memlet], data: dt.Structure, connector: str, direction: dirtype) -> Set[str]: @@ -251,23 +447,34 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Dict[str, Set[str]]]: result = defaultdict(set) lifted_something = False - for nsdfg in sdfg.all_sdfgs_recursive(): - for state in nsdfg.states(): - for node in state.data_nodes(): - cont = nsdfg.data(node.data) - if (isinstance(cont, (dt.Structure, dt.StructureView, dt.StructureReference)) or - (isinstance(cont, (dt.ContainerView, dt.ContainerArray, dt.ContainerArrayReference)) and - isinstance(cont.stype, dt.Structure))): - for oedge in state.out_edges(node): - if isinstance(oedge.dst, nd.Tasklet): - res = self._lift_tasklet(state, node, oedge.dst, oedge, cont, oedge.dst_conn, 'in') - result[node.data].update(res) - lifted_something = True - for iedge in state.in_edges(node): - if isinstance(iedge.src, nd.Tasklet): - res = self._lift_tasklet(state, node, iedge.src, iedge, cont, iedge.src_conn, 'out') - result[node.data].update(res) - lifted_something = True + while True: + lifted_something_this_round = False + for cfg in sdfg.all_control_flow_regions(recursive=True): + for block in cfg.nodes(): + if isinstance(block, SDFGState): + for node in block.data_nodes(): + cont = cfg.sdfg.data(node.data) + if (isinstance(cont, (dt.Structure, dt.StructureView, dt.StructureReference)) or + (isinstance(cont, (dt.ContainerView, dt.ContainerArray, dt.ContainerArrayReference)) and + isinstance(cont.stype, dt.Structure))): + for oedge in block.out_edges(node): + if isinstance(oedge.dst, nd.Tasklet): + res = self._lift_tasklet(block, node, oedge.dst, oedge, cont, oedge.dst_conn, + 'in') + result[node.data].update(res) + lifted_something_this_round = True + for iedge in block.in_edges(node): + if isinstance(iedge.src, nd.Tasklet): + res = self._lift_tasklet(block, node, iedge.src, iedge, cont, iedge.src_conn, + 'out') + result[node.data].update(res) + lifted_something_this_round = True + for edge in cfg.edges(): + lifted_something_this_round |= self._lift_isedge(cfg, edge, result) + if not lifted_something_this_round: + break + else: + lifted_something = True if not lifted_something: return None