Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add external nested SDFG capabilities #1795

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 56 additions & 44 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ class NestedSDFG(CodeNode):

# NOTE: We cannot use SDFG as the type because of an import loop
sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
ext_sdfg_path = Property(dtype=str, default=None, allow_none=True,
desc='Path to a file containing the SDFG for this nested SDFG')
schedule = EnumProperty(dtype=dtypes.ScheduleType,
desc="SDFG schedule",
allow_none=True,
Expand All @@ -569,22 +571,30 @@ class NestedSDFG(CodeNode):

def __init__(self,
label,
sdfg,
sdfg: Optional['dace.SDFG'],
inputs: Set[str],
outputs: Set[str],
symbol_mapping: Dict[str, Any] = None,
schedule=dtypes.ScheduleType.Default,
location=None,
debuginfo=None):
from dace.sdfg import SDFG
debuginfo=None,
path: Optional[str] = None):
super(NestedSDFG, self).__init__(label, location, inputs, outputs)

# Properties
self.sdfg: SDFG = sdfg
self.sdfg: 'dace.SDFG' = sdfg
self.ext_sdfg_path = path
self.symbol_mapping = symbol_mapping or {}
self.schedule = schedule
self.debuginfo = debuginfo

def load_external(self, context: Optional['dace.SDFGState']) -> None:
if self.sdfg is None and self.ext_sdfg_path is not None:
self.sdfg = dace.SDFG.from_file(self.ext_sdfg_path)
self.sdfg.parent_nsdfg_node = self
self.sdfg.parent = context
self.sdfg.parent_sdfg = context.sdfg if context else None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
Expand All @@ -607,14 +617,14 @@ def from_json(json_obj, context=None):

dace.serialize.set_properties_from_json(ret, json_obj, context)

if context and 'sdfg_state' in context:
ret.sdfg.parent = context['sdfg_state']
if context and 'sdfg' in context:
ret.sdfg.parent_sdfg = context['sdfg']

ret.sdfg.parent_nsdfg_node = ret
if ret.sdfg is not None:
if context and 'sdfg_state' in context:
ret.sdfg.parent = context['sdfg_state']
if context and 'sdfg' in context:
ret.sdfg.parent_sdfg = context['sdfg']
ret.sdfg.parent_nsdfg_node = ret

ret.sdfg.update_cfg_list([])
ret.sdfg.update_cfg_list([])

return ret

Expand Down Expand Up @@ -664,28 +674,29 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
for out_conn in self.out_connectors:
if not dtypes.validate_name(out_conn):
raise NameError('Invalid output connector "%s"' % out_conn)
if self.sdfg.parent_nsdfg_node is not self:
raise ValueError('Parent nested SDFG node not properly set')
if self.sdfg.parent is not state:
raise ValueError('Parent state not properly set for nested SDFG node')
if self.sdfg.parent_sdfg is not sdfg:
raise ValueError('Parent SDFG not properly set for nested SDFG node')

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)
if self.sdfg:
if self.sdfg.parent_nsdfg_node is not self:
raise ValueError('Parent nested SDFG node not properly set')
if self.sdfg.parent is not state:
raise ValueError('Parent state not properly set for nested SDFG node')
if self.sdfg.parent_sdfg is not sdfg:
raise ValueError('Parent SDFG not properly set for nested SDFG node')

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)

# Validate inout connectors
from dace.sdfg import utils # Avoids circular import
Expand All @@ -706,17 +717,18 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
f"output ({outputs}) arrays")

# Validate undefined symbols
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols))
extra_symbols = self.symbol_mapping.keys() - symbols
if len(extra_symbols) > 0:
# TODO: Elevate to an error?
warnings.warn(f"{self.label} maps to unused symbol(s): {extra_symbols}")

# Recursively validate nested SDFG
self.sdfg.validate(references, **context)
if self.sdfg:
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols))
extra_symbols = self.symbol_mapping.keys() - symbols
if len(extra_symbols) > 0:
# TODO: Elevate to an error?
warnings.warn(f"{self.label} maps to unused symbol(s): {extra_symbols}")

# Recursively validate nested SDFG
self.sdfg.validate(references, **context)


# ------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,6 +2341,10 @@ def compile(self, output_file=None, validate=True,
# if the codegen modifies the SDFG (thereby changing its hash)
sdfg.build_folder = build_folder

# Ensure external nested SDFGs are loaded.
for _ in sdfg.all_sdfgs_recursive(load_ext=True):
pass

# Rename SDFG to avoid runtime issues with clashing names
index = 0
while sdfg.is_loaded():
Expand Down
85 changes: 48 additions & 37 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]:
def all_nodes_recursive(self, predicate = None) -> Iterator[Tuple[NodeT, GraphT]]:
for node in self.nodes():
yield node, self
if isinstance(node, nd.NestedSDFG):
if isinstance(node, nd.NestedSDFG) and node.sdfg:
if predicate is None or predicate(node, self):
yield from node.sdfg.all_nodes_recursive(predicate)

Expand Down Expand Up @@ -1380,7 +1380,7 @@ def add_node(self, node):
if not isinstance(node, nd.Node):
raise TypeError("Expected Node, got " + type(node).__name__ + " (" + str(node) + ")")
# Correct nested SDFG's parent attributes
if isinstance(node, nd.NestedSDFG):
if isinstance(node, nd.NestedSDFG) and node.sdfg is not None:
node.sdfg.parent = self
node.sdfg.parent_sdfg = self.sdfg
node.sdfg.parent_nsdfg_node = node
Expand Down Expand Up @@ -1667,7 +1667,7 @@ def add_tasklet(

def add_nested_sdfg(
self,
sdfg: 'SDFG',
sdfg: Optional['SDFG'],
parent,
inputs: Union[Set[str], Dict[str, dtypes.typeclass]],
outputs: Union[Set[str], Dict[str, dtypes.typeclass]],
Expand All @@ -1676,16 +1676,21 @@ def add_nested_sdfg(
schedule=dtypes.ScheduleType.Default,
location=None,
debuginfo=None,
external_path: Optional[str] = None,
):
""" Adds a nested SDFG to the SDFG state. """
if name is None:
name = sdfg.label
debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo)

sdfg.parent = self
sdfg.parent_sdfg = self.sdfg
if sdfg is None and external_path is None:
raise ValueError('Neither an SDFG nor an external SDFG path has been provided')

if sdfg is not None:
sdfg.parent = self
sdfg.parent_sdfg = self.sdfg

sdfg.update_cfg_list([])
sdfg.update_cfg_list([])

# Make dictionary of autodetect connector types from set
if isinstance(inputs, (set, collections.abc.KeysView)):
Expand All @@ -1702,35 +1707,37 @@ def add_nested_sdfg(
schedule=schedule,
location=location,
debuginfo=debuginfo,
path=external_path,
)
self.add_node(s)

sdfg.parent_nsdfg_node = s

# Add "default" undefined symbols if None are given
symbols = sdfg.free_symbols
if symbol_mapping is None:
symbol_mapping = {s: s for s in symbols}
s.symbol_mapping = symbol_mapping

# Validate missing symbols
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols and parent:
# If symbols are missing, try to get them from the parent SDFG
parent_mapping = {s: s for s in missing_symbols if s in parent.symbols}
symbol_mapping.update(parent_mapping)
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))
if sdfg is not None:
sdfg.parent_nsdfg_node = s

# Add new global symbols to nested SDFG
from dace.codegen.tools.type_inference import infer_expr_type
for sym, symval in s.symbol_mapping.items():
if sym not in sdfg.symbols:
# TODO: Think of a better way to avoid calling
# symbols_defined_at in this moment
sdfg.add_symbol(sym, infer_expr_type(symval, self.sdfg.symbols) or dtypes.typeclass(int))
# Add "default" undefined symbols if None are given
symbols = sdfg.free_symbols
if symbol_mapping is None:
symbol_mapping = {s: s for s in symbols}
s.symbol_mapping = symbol_mapping

# Validate missing symbols
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols and parent:
# If symbols are missing, try to get them from the parent SDFG
parent_mapping = {s: s for s in missing_symbols if s in parent.symbols}
symbol_mapping.update(parent_mapping)
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))

# Add new global symbols to nested SDFG
from dace.codegen.tools.type_inference import infer_expr_type
for sym, symval in s.symbol_mapping.items():
if sym not in sdfg.symbols:
# TODO: Think of a better way to avoid calling
# symbols_defined_at in this moment
sdfg.add_symbol(sym, infer_expr_type(symval, self.sdfg.symbols) or dtypes.typeclass(int))

return s

Expand Down Expand Up @@ -2818,23 +2825,27 @@ def add_state_after(self,
###################################################################
# Traversal methods

def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']:
def all_control_flow_regions(self, recursive=False, load_ext=False) -> Iterator['ControlFlowRegion']:
""" Iterate over this and all nested control flow regions. """
yield self
for block in self.nodes():
if isinstance(block, SDFGState) and recursive:
for node in block.nodes():
if isinstance(node, nd.NestedSDFG):
yield from node.sdfg.all_control_flow_regions(recursive=recursive)
if node.sdfg:
yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif load_ext:
node.load_external(block)
yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif isinstance(block, ControlFlowRegion):
yield from block.all_control_flow_regions(recursive=recursive)
yield from block.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif isinstance(block, ConditionalBlock):
for _, branch in block.branches:
yield from branch.all_control_flow_regions(recursive=recursive)
yield from branch.all_control_flow_regions(recursive=recursive, load_ext=load_ext)

def all_sdfgs_recursive(self) -> Iterator['SDFG']:
def all_sdfgs_recursive(self, load_ext=False) -> Iterator['SDFG']:
""" Iterate over this and all nested SDFGs. """
for cfg in self.all_control_flow_regions(recursive=True):
for cfg in self.all_control_flow_regions(recursive=True, load_ext=load_ext):
if isinstance(cfg, dace.SDFG):
yield cfg

Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/fusion_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def modifies(self) -> ppl.Modifies:
def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]:
modified = 0
for node, state in sdfg.all_nodes_recursive():
if not isinstance(node, nodes.NestedSDFG):
if not isinstance(node, nodes.NestedSDFG) or node.sdfg is None:
continue
was_modified = False
if node.sdfg.parent_nsdfg_node is not node:
Expand Down
57 changes: 56 additions & 1 deletion tests/nested_sdfg_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import os
import tempfile
import numpy as np

import dace as dp
Expand Down Expand Up @@ -54,5 +56,58 @@ def do():
assert diff <= 1e-5


def test_external_nsdfg():
N = dp.symbol('N')

@dp.program
def sdfg_internal(input: dp.float32, output: dp.float32[1]):
@dp.tasklet
def init():
out >> output
out = input

for k in range(4):

@dp.tasklet
def do():
oin << output
out >> output
out = oin * input


# Construct SDFG
mysdfg = SDFG('outer_sdfg')
state = mysdfg.add_state()
A = state.add_array('A', [N, N], dp.float32)
B = state.add_array('B', [N, N], dp.float32)

map_entry, map_exit = state.add_map('elements', [('i', '0:N'), ('j', '0:N')])
internal = sdfg_internal.to_sdfg()
fd, filename = tempfile.mkstemp(suffix='.sdfg')
internal.save(filename)
nsdfg = state.add_nested_sdfg(None, mysdfg, {'input'}, {'output'}, name='sdfg_internal', external_path=filename)

# Add edges
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='input', memlet=Memlet.simple(A, 'i,j'))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='output', memlet=Memlet.simple(B, 'i,j'))


N = 64

input = dp.ndarray([N, N], dp.float32)
output = dp.ndarray([N, N], dp.float32)
input[:] = np.random.rand(N, N).astype(dp.float32.type)
output[:] = dp.float32(0)

mysdfg(A=input, B=output, N=N)

diff = np.linalg.norm(output - np.power(input, 5)) / (N * N)
print("Difference:", diff)
assert diff <= 1e-5

os.close(fd)


if __name__ == "__main__":
test()
test_external_nsdfg()