Skip to content

Commit

Permalink
Merge pull request #1221 from spcl/fix-nested-sdfg-deepcopy
Browse files Browse the repository at this point in the history
Fix-nested-sdfg-deepcopy
  • Loading branch information
alexnick83 authored Mar 20, 2023
2 parents ce68be9 + 3cf917f commit 202ad49
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 1 deletion.
10 changes: 10 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,16 @@ def __init__(self,
self.symbol_mapping = symbol_mapping or {}
self.schedule = schedule
self.debuginfo = debuginfo

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, dcpy(v, memo))
if result._sdfg is not None:
result._sdfg.parent_nsdfg_node = result
return result

@staticmethod
def from_json(json_obj, context=None):
Expand Down
32 changes: 31 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,34 @@ def __init__(self,
self._orig_name = name
self._num = 0

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# Skip derivative attributes
if k in ('_cached_start_state', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node',
'_sdfg_list', '_transformation_hist'):
continue
setattr(result, k, copy.deepcopy(v, memo))
# Copy edges and nodes
result._edges = copy.deepcopy(self._edges, memo)
result._nodes = copy.deepcopy(self._nodes, memo)
result._cached_start_state = copy.deepcopy(self._cached_start_state, memo)
# Copy parent attributes
for k in ('_parent', '_parent_sdfg', '_parent_nsdfg_node'):
if id(getattr(self, k)) in memo:
setattr(result, k, memo[id(getattr(self, k))])
else:
setattr(result, k, None)
# Copy SDFG list and transformation history
if hasattr(self, '_transformation_hist'):
setattr(result, '_transformation_hist', copy.deepcopy(self._transformation_hist, memo))
result._sdfg_list = []
if self._parent_sdfg is None:
result._sdfg_list = result.reset_sdfg_list()
return result

@property
def sdfg_id(self):
"""
Expand Down Expand Up @@ -520,6 +548,7 @@ def hash_sdfg(self, jsondict: Optional[Dict[str, Any]] = None) -> str:
:param jsondict: If not None, uses given JSON dictionary as input.
:return: The hash (in SHA-256 format).
"""

def keyword_remover(json_obj: Any, last_keyword=""):
# Makes non-unique in SDFG hierarchy v2
# Recursively remove attributes from the SDFG which are not used in
Expand Down Expand Up @@ -1910,7 +1939,8 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str
if find_new_name:
name = self._find_new_name(name)
else:
raise NameError('Array or Stream with name "%s" already exists ' "in SDFG" % name)
raise NameError('Array or Stream with name "%s" already exists '
"in SDFG" % name)
self._arrays[name] = datadesc

# Add free symbols to the SDFG global symbol storage
Expand Down
21 changes: 21 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,22 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None):
self.nosync = False
self.location = location if location is not None else {}
self._default_lineinfo = None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
for node in result.nodes():
if isinstance(node, nd.NestedSDFG):
try:
node.sdfg.parent = result
except AttributeError:
# NOTE: There are cases where a NestedSDFG does not have `sdfg` attribute.
# TODO: Investigate why this happens.
pass
return result

@property
def parent(self):
Expand Down Expand Up @@ -819,6 +835,11 @@ def all_edges_and_connectors(self, *nodes):
def add_node(self, node):
if not isinstance(node, nd.Node):
raise TypeError("Expected Node, got " + type(node).__name__ + " (" + str(node) + ")")
# Correct nested SDFG's parent attributes
if isinstance(node, nd.NestedSDFG):
node.sdfg.parent = self
node.sdfg.parent_sdfg = self.parent
node.sdfg.parent_nsdfg_node = node
self._clear_scopedict_cache()
return super(SDFGState, self).add_node(node)

Expand Down
16 changes: 16 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,3 +1629,19 @@ def map_view_to_array(vdesc: dt.View, adesc: dt.Array,
squeezed.append(i)

return dimension_mapping, unsqueezed, squeezed


def check_sdfg(sdfg: SDFG):
""" Checks that the parent attributes of an SDFG are correct.
:param sdfg: The SDFG to check.
:raises AssertionError: If any of the parent attributes are incorrect.
"""
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
assert node.sdfg.parent_nsdfg_node is node
assert node.sdfg.parent is state
assert node.sdfg.parent_sdfg is sdfg
assert node.sdfg.parent.parent is sdfg
check_sdfg(node.sdfg)
2 changes: 2 additions & 0 deletions dace/transformation/interstate/multistate_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
# Remove nested SDFG and state
sdfg.remove_node(outer_state)

sdfg._sdfg_list = sdfg.reset_sdfg_list()

return nsdfg.nodes()

# def _modify_access_to_access(
Expand Down
2 changes: 2 additions & 0 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,8 @@ def apply(self, state: SDFGState, sdfg: SDFG):
for dnode in state.data_nodes():
if state.degree(dnode) == 0 and dnode not in isolated_nodes:
state.remove_node(dnode)

sdfg._sdfg_list = sdfg.reset_sdfg_list()

def _modify_access_to_access(self,
input_edges: Dict[nodes.Node, MultiConnectorEdge],
Expand Down
1 change: 1 addition & 0 deletions dace/transformation/subgraph/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool:
# deepcopy
graph_indices = [i for (i, n) in enumerate(graph.nodes()) if n in subgraph]
sdfg_copy = copy.deepcopy(sdfg)
sdfg_copy.reset_sdfg_list()
graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)]
subgraph_copy = SubgraphView(graph_copy, [graph_copy.nodes()[i] for i in graph_indices])
expansion.sdfg_id = sdfg_copy.sdfg_id
Expand Down
153 changes: 153 additions & 0 deletions tests/sdfg/nested_sdfg_deepcopy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests deepcopying (nested) SDFGs. """
import copy
import dace
import numpy as np


def test_deepcopy_same_state():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state = sdfg.add_state('state')

nsdfg = dace.SDFG('nested')
nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

state.add_node(copy_nsdfg)
assert copy_nsdfg.sdfg.parent is state
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_same_state_edge():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state = sdfg.add_state('state')

nsdfg = dace.SDFG('nested')
nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

state.add_edge(nsdfg_node, None, copy_nsdfg, None, dace.Memlet())
assert copy_nsdfg.sdfg.parent is state
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_diff_state():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state_0 = sdfg.add_state('state_0')
state_1 = sdfg.add_state('state_1')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

state_1.add_node(copy_nsdfg)
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_diff_state_edge():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
sdfg.add_array('A', [1], dace.int32)
state_0 = sdfg.add_state('state_0')
state_1 = sdfg.add_state('state_1')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

a = state_1.add_access('A')
state_1.add_edge(a, None, copy_nsdfg, None, dace.Memlet())
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg


def test_deepcopy_diff_sdfg():

sdfg_0 = dace.SDFG('deepcopy_nested_sdfg_0')
state_0 = sdfg_0.add_state('state_0')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

sdfg_1 = dace.SDFG('deepcopy_nested_sdfg_1')
state_1 = sdfg_1.add_state('state_1')

state_1.add_node(copy_nsdfg)
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg_1


def test_deepcopy_diff_sdfg_edge():

sdfg_0 = dace.SDFG('deepcopy_nested_sdfg_0')
state_0 = sdfg_0.add_state('state_0')

nsdfg = dace.SDFG('nested')
nsdfg_node = state_0.add_nested_sdfg(nsdfg, None, {}, {})

copy_nsdfg = copy.deepcopy(nsdfg_node)
assert copy_nsdfg.sdfg.parent_nsdfg_node is copy_nsdfg
assert copy_nsdfg.sdfg.parent is None
assert copy_nsdfg.sdfg.parent_sdfg is None

sdfg_1 = dace.SDFG('deepcopy_nested_sdfg_1')
sdfg_1.add_array('A', [1], dace.int32)
state_1 = sdfg_1.add_state('state_1')

a = state_1.add_access('A')
state_1.add_edge(a, None, copy_nsdfg, None, dace.Memlet())
assert copy_nsdfg.sdfg.parent is state_1
assert copy_nsdfg.sdfg.parent_sdfg is sdfg_1


def test_deepcopy_top_level():

sdfg = dace.SDFG('deepcopy_nested_sdfg')
state = sdfg.add_state('state')

nsdfg = dace.SDFG('nested')
nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {})

copy_sdfg = copy.deepcopy(sdfg)
copy_state = copy_sdfg.states()[0]
copy_nsdfg_node = copy_state.nodes()[0]
for sd in copy_sdfg.all_sdfgs_recursive():
if sd is copy_sdfg:
continue
assert sd.parent_nsdfg_node is copy_nsdfg_node
assert sd.parent is copy_state
assert sd.parent_sdfg is copy_sdfg


if __name__ == '__main__':
test_deepcopy_same_state()
test_deepcopy_same_state_edge()
test_deepcopy_diff_state()
test_deepcopy_diff_state_edge()
test_deepcopy_diff_sdfg()
test_deepcopy_top_level()

0 comments on commit 202ad49

Please sign in to comment.