Skip to content

Commit

Permalink
Let LiftStructViews lift interstate edge struct accesses to views (#1827
Browse files Browse the repository at this point in the history
)

`LiftStructViews` is a necessary pass for ensuring structure member
accesses and container array accesses work correctly, by lifting
accesses to them into a 'tower of views' that correctly traces their
accesses to the root data container. However, this previously did not
consider interstate edges. This PR fixes this by making two changes:
- Allow `LiftStructViews` to lift struct accesses on interstate edges to
the correct tower of views, by inserting a lifting state right before
the interstate edge, where views are constructed. The interstate edge is
then re-written to access the correct view instead.
- Fix `ArrayElimination` to correctly handle these views, which are not
directly connected to other nodes, by preventing it from incorrectly
merging or removing them. This is done by allowing source/sink view
nodes to be merged correctly, i.e., when they are 'the same view', i.e.,
their view memlets are identical.

This shows an example of the constructed tower of views for a complex
interstate edge struct access:


![image](https://github.com/user-attachments/assets/bc8173fe-a6a9-4592-97d9-f1b02a6da740)
  • Loading branch information
phschaad authored Dec 16, 2024
1 parent 6156243 commit e82870a
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 35 deletions.
6 changes: 3 additions & 3 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,16 +867,16 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg
return None

in_edge = in_edges[0]
out_edge = out_edges[0]
out_edge = out_edges[0] if len(out_edges) > 0 else None

# If there is one incoming and one outgoing edge, and one leads to a code
# node, the one that leads to an access node is the viewed data.
inmpath = state.memlet_path(in_edge)
outmpath = state.memlet_path(out_edge)
outmpath = state.memlet_path(out_edge) if out_edge else None
src_is_data, dst_is_data = False, False
if isinstance(inmpath[0].src, nd.AccessNode):
src_is_data = True
if isinstance(outmpath[-1].dst, nd.AccessNode):
if outmpath and isinstance(outmpath[-1].dst, nd.AccessNode):
dst_is_data = True

if src_is_data and not dst_is_data:
Expand Down
60 changes: 47 additions & 13 deletions dace/transformation/passes/array_elimination.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set

from dace import SDFG, SDFGState, data, properties
from dace.memlet import Memlet
from dace.sdfg import nodes
from dace.sdfg.analysis import cfg
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg.validation import InvalidSDFGNodeError
from dace.transformation import pass_pipeline as ppl, transformation
from dace.transformation.dataflow import (RedundantArray, RedundantReadSlice, RedundantSecondArray, RedundantWriteSlice,
SqueezeViewRemove, UnsqueezeViewRemove, RemoveSliceView)
Expand Down Expand Up @@ -66,9 +69,6 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S
removed_nodes = self.merge_access_nodes(state, access_nodes, lambda n: state.in_degree(n) == 0)
removed_nodes |= self.merge_access_nodes(state, access_nodes, lambda n: state.out_degree(n) == 0)

# Update access nodes with merged nodes
access_nodes = {k: [n for n in v if n not in removed_nodes] for k, v in access_nodes.items()}

# Remove redundant views
removed_nodes |= self.remove_redundant_views(sdfg, state, access_nodes)

Expand Down Expand Up @@ -105,25 +105,59 @@ def merge_access_nodes(self, state: SDFGState, access_nodes: Dict[str, List[node
Merges access nodes that follow the same conditions together to the first access node.
"""
removed_nodes: Set[nodes.AccessNode] = set()
for nodeset in access_nodes.values():
for data_container in access_nodes.keys():
nodeset = access_nodes[data_container]
if len(nodeset) > 1:
# Merge all other access nodes to the first one
first_node = nodeset[0]
if not condition(first_node):
# Merge all other access nodes to the first one that fits the condition, if one exists.
first_node = None
first_node_idx = 0
for i, node in enumerate(nodeset[:-1]):
if condition(node):
first_node = node
first_node_idx = i
break
if first_node is None:
continue
for node in nodeset[1:]:

for node in nodeset[first_node_idx + 1:]:
if not condition(node):
continue

# Reconnect edges to first node
for edge in state.all_edges(node):
# Reconnect edges to first node.
# If we are handling views, we do not want to add more than one edge going into a 'views' connector,
# so we only merge nodes if the memlets match exactly (which they should). But in that case without
# copying the edge.
edges: List[MultiConnectorEdge[Memlet]] = state.all_edges(node)
other_edges: List[MultiConnectorEdge[Memlet]] = []
for edge in edges:
if edge.dst is node:
state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data)
if edge.dst_conn == 'views':
other_edges = list(state.in_edges_by_connector(first_node, 'views'))
if len(other_edges) != 1:
raise InvalidSDFGNodeError('Multiple edges connected to views connector',
state.sdfg, state.block_id, state.node_id(first_node))
other_view_edge = other_edges[0]
if other_view_edge.data != edge.data:
# The memlets do not match, skip the node.
continue
else:
state.add_edge(edge.src, edge.src_conn, first_node, edge.dst_conn, edge.data)
else:
state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
if edge.src_conn == 'views':
other_edges = list(state.out_edges_by_connector(first_node, 'views'))
if len(other_edges) != 1:
raise InvalidSDFGNodeError('Multiple edges connected to views connector',
state.sdfg, state.block_id, state.node_id(first_node))
other_view_edge = other_edges[0]
if other_view_edge.data != edge.data:
# The memlets do not match, skip the node.
continue
else:
state.add_edge(first_node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
# Remove merged node and associated edges
state.remove_node(node)
removed_nodes.add(node)
access_nodes[data_container] = [n for n in nodeset if n not in removed_nodes]
return removed_nodes

def remove_redundant_views(self, sdfg: SDFG, state: SDFGState, access_nodes: Dict[str, List[nodes.AccessNode]]):
Expand Down
Loading

0 comments on commit e82870a

Please sign in to comment.