Commit 26bd8ba

Merge branch 'main' into symbol_redeclaration_fix

ThrudPrimrose committed Dec 2, 2024
2 parents 0b31f9a + 0e2b39a
Showing 12 changed files with 833 additions and 98 deletions.
3 changes: 3 additions & 0 deletions dace/codegen/targets/cuda.py
@@ -2620,6 +2620,9 @@ def _generate_NestedSDFG(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub

def _generate_MapExit(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int,
node: nodes.MapExit, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None:
if isinstance(node, nodes.MapExit) and node.map.gpu_force_syncthreads:
callsite_stream.write('__syncthreads();', cfg, state_id)

if node.map.schedule == dtypes.ScheduleType.GPU_Device:
# Remove grid invocation conditions
for i in range(len(node.map.params)):
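The __syncthreads() emission above is driven by the new gpu_force_syncthreads property that this commit adds to Map (see dace/sdfg/nodes.py below). A minimal sketch of how the flag could be set from the Python API; the SDFG, array, and map names are invented for illustration and are not part of the commit:

import dace
from dace import dtypes

# Illustrative SDFG with a single GPU-scheduled map.
sdfg = dace.SDFG('force_sync_example')
sdfg.add_array('A', [128], dace.float32, storage=dtypes.StorageType.GPU_Global)
state = sdfg.add_state()
map_entry, map_exit = state.add_map('gpu_map', dict(i='0:128'),
                                    schedule=dtypes.ScheduleType.GPU_Device)

# New in this commit: request an explicit barrier when the map closes, so the
# CUDA code generator emits '__syncthreads();' at the corresponding MapExit.
map_entry.map.gpu_force_syncthreads = True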
100 changes: 57 additions & 43 deletions dace/sdfg/nodes.py
@@ -545,6 +545,8 @@ class NestedSDFG(CodeNode):

# NOTE: We cannot use SDFG as the type because of an import loop
sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
ext_sdfg_path = Property(dtype=str, default=None, allow_none=True,
desc='Path to a file containing the SDFG for this nested SDFG')
schedule = EnumProperty(dtype=dtypes.ScheduleType,
desc="SDFG schedule",
allow_none=True,
@@ -569,22 +571,30 @@ class NestedSDFG(CodeNode):

def __init__(self,
label,
sdfg,
sdfg: Optional['dace.SDFG'],
inputs: Set[str],
outputs: Set[str],
symbol_mapping: Dict[str, Any] = None,
schedule=dtypes.ScheduleType.Default,
location=None,
debuginfo=None):
from dace.sdfg import SDFG
debuginfo=None,
path: Optional[str] = None):
super(NestedSDFG, self).__init__(label, location, inputs, outputs)

# Properties
self.sdfg: SDFG = sdfg
self.sdfg: 'dace.SDFG' = sdfg
self.ext_sdfg_path = path
self.symbol_mapping = symbol_mapping or {}
self.schedule = schedule
self.debuginfo = debuginfo

def load_external(self, context: Optional['dace.SDFGState']) -> None:
if self.sdfg is None and self.ext_sdfg_path is not None:
self.sdfg = dace.SDFG.from_file(self.ext_sdfg_path)
self.sdfg.parent_nsdfg_node = self
self.sdfg.parent = context
self.sdfg.parent_sdfg = context.sdfg if context else None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
@@ -607,14 +617,14 @@ def from_json(json_obj, context=None):

dace.serialize.set_properties_from_json(ret, json_obj, context)

if context and 'sdfg_state' in context:
ret.sdfg.parent = context['sdfg_state']
if context and 'sdfg' in context:
ret.sdfg.parent_sdfg = context['sdfg']
if ret.sdfg is not None:
if context and 'sdfg_state' in context:
ret.sdfg.parent = context['sdfg_state']
if context and 'sdfg' in context:
ret.sdfg.parent_sdfg = context['sdfg']
ret.sdfg.parent_nsdfg_node = ret

ret.sdfg.parent_nsdfg_node = ret

ret.sdfg.update_cfg_list([])
ret.sdfg.update_cfg_list([])

return ret
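The new path argument and load_external method let a NestedSDFG node be created as a lightweight reference to an SDFG stored on disk, which is only read when explicitly requested (or during compilation, see dace/sdfg/sdfg.py below). A hedged sketch; the file name 'inner.sdfg' and the connector names are assumptions for illustration:

import dace
from dace.sdfg import nodes

# Assumes 'inner.sdfg' was previously written with inner_sdfg.save('inner.sdfg')
# and exposes non-transient arrays named 'a_in' and 'a_out'.
nsdfg_node = nodes.NestedSDFG('inner', sdfg=None,
                              inputs={'a_in'}, outputs={'a_out'},
                              path='inner.sdfg')

# Nothing is read from disk until load_external is called; passing a parent
# SDFGState as context also fixes up the parent/parent_sdfg back-references.
nsdfg_node.load_external(None)
assert nsdfg_node.sdfg is not None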

@@ -664,28 +674,29 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
for out_conn in self.out_connectors:
if not dtypes.validate_name(out_conn):
raise NameError('Invalid output connector "%s"' % out_conn)
if self.sdfg.parent_nsdfg_node is not self:
raise ValueError('Parent nested SDFG node not properly set')
if self.sdfg.parent is not state:
raise ValueError('Parent state not properly set for nested SDFG node')
if self.sdfg.parent_sdfg is not sdfg:
raise ValueError('Parent SDFG not properly set for nested SDFG node')

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)
if self.sdfg:
if self.sdfg.parent_nsdfg_node is not self:
raise ValueError('Parent nested SDFG node not properly set')
if self.sdfg.parent is not state:
raise ValueError('Parent state not properly set for nested SDFG node')
if self.sdfg.parent_sdfg is not sdfg:
raise ValueError('Parent SDFG not properly set for nested SDFG node')

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)

# Validate inout connectors
from dace.sdfg import utils # Avoids circular import
@@ -706,17 +717,18 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
f"output ({outputs}) arrays")

# Validate undefined symbols
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols))
extra_symbols = self.symbol_mapping.keys() - symbols
if len(extra_symbols) > 0:
# TODO: Elevate to an error?
warnings.warn(f"{self.label} maps to unused symbol(s): {extra_symbols}")
if self.sdfg:
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols))
extra_symbols = self.symbol_mapping.keys() - symbols
if len(extra_symbols) > 0:
# TODO: Elevate to an error?
warnings.warn(f"{self.label} maps to unused symbol(s): {extra_symbols}")

# Recursively validate nested SDFG
self.sdfg.validate(references, **context)
# Recursively validate nested SDFG
self.sdfg.validate(references, **context)


# ------------------------------------------------------------------------------
@@ -930,6 +942,8 @@ class Map(object):
"(including tuples) sets it explicitly.",
serialize_if=lambda m: m.schedule in dtypes.GPU_SCHEDULES)

gpu_force_syncthreads = Property(dtype=bool, desc="Force a call to the __syncthreads for the map", default=False)

def __init__(self,
label,
params,
18 changes: 13 additions & 5 deletions dace/sdfg/sdfg.py
@@ -1051,7 +1051,7 @@ def clear_data_reports(self):

def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, **kwargs):
"""
Invokes an SDFG with an instrumented data report, generating and compiling code if necessary.
Invokes an SDFG with an instrumented data report, generating and compiling code if necessary.
Arguments given as ``args`` and ``kwargs`` will be overriden by the data containers defined in the report.
:param dreport: The instrumented data report to use upon calling.
@@ -2341,6 +2341,10 @@ def compile(self, output_file=None, validate=True,
# if the codegen modifies the SDFG (thereby changing its hash)
sdfg.build_folder = build_folder

# Ensure external nested SDFGs are loaded.
for _ in sdfg.all_sdfgs_recursive(load_ext=True):
pass

# Rename SDFG to avoid runtime issues with clashing names
index = 0
while sdfg.is_loaded():
@@ -2690,7 +2694,7 @@ def apply_transformations_once_everywhere(self,
print_report: Optional[bool] = None,
order_by_transformation: bool = True,
progress: Optional[bool] = None) -> int:
"""
"""
This function applies a transformation or a set of (unique) transformations
until throughout the entire SDFG once. Operates in-place.
@@ -2738,7 +2742,9 @@ def apply_gpu_transformations(self,
permissive=False,
sequential_innermaps=True,
register_transients=True,
simplify=True):
simplify=True,
host_maps=None,
host_data=None):
""" Applies a series of transformations on the SDFG for it to
generate GPU code.
@@ -2755,7 +2761,9 @@ def apply_gpu_transformations(self,
self.apply_transformations(GPUTransformSDFG,
options=dict(sequential_innermaps=sequential_innermaps,
register_trans=register_transients,
simplify=simplify),
simplify=simplify,
host_maps=host_maps,
host_data=host_data),
validate=validate,
validate_all=validate_all,
permissive=permissive,
@@ -2806,7 +2814,7 @@ def expand_library_nodes(self, recursive=True):

def generate_code(self):
""" Generates code from this SDFG and returns it.
:return: A list of `CodeObject` objects containing the generated
code of different files and languages.
"""
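The two new keyword arguments are forwarded as options to GPUTransformSDFG. A sketch of the call, under the assumption that host_maps and host_data name maps and data containers that should remain on the host (their exact semantics are defined by GPUTransformSDFG and are not shown in this diff):

import dace

N = dace.symbol('N')

@dace.program
def scale(a: dace.float64[N], b: dace.float64[N]):
    b[:] = 2.0 * a

sdfg = scale.to_sdfg()

# 'b' is a placeholder choice of a container to keep on the host.
sdfg.apply_gpu_transformations(host_maps=[], host_data=['b'])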
85 changes: 48 additions & 37 deletions dace/sdfg/state.py
@@ -350,7 +350,7 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]:
def all_nodes_recursive(self, predicate = None) -> Iterator[Tuple[NodeT, GraphT]]:
for node in self.nodes():
yield node, self
if isinstance(node, nd.NestedSDFG):
if isinstance(node, nd.NestedSDFG) and node.sdfg:
if predicate is None or predicate(node, self):
yield from node.sdfg.all_nodes_recursive(predicate)

@@ -1380,7 +1380,7 @@ def add_node(self, node):
if not isinstance(node, nd.Node):
raise TypeError("Expected Node, got " + type(node).__name__ + " (" + str(node) + ")")
# Correct nested SDFG's parent attributes
if isinstance(node, nd.NestedSDFG):
if isinstance(node, nd.NestedSDFG) and node.sdfg is not None:
node.sdfg.parent = self
node.sdfg.parent_sdfg = self.sdfg
node.sdfg.parent_nsdfg_node = node
@@ -1667,7 +1667,7 @@ def add_tasklet(

def add_nested_sdfg(
self,
sdfg: 'SDFG',
sdfg: Optional['SDFG'],
parent,
inputs: Union[Set[str], Dict[str, dtypes.typeclass]],
outputs: Union[Set[str], Dict[str, dtypes.typeclass]],
@@ -1676,16 +1676,21 @@ def add_nested_sdfg(
schedule=dtypes.ScheduleType.Default,
location=None,
debuginfo=None,
external_path: Optional[str] = None,
):
""" Adds a nested SDFG to the SDFG state. """
if name is None:
name = sdfg.label
debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo)

sdfg.parent = self
sdfg.parent_sdfg = self.sdfg
if sdfg is None and external_path is None:
raise ValueError('Neither an SDFG nor an external SDFG path has been provided')

if sdfg is not None:
sdfg.parent = self
sdfg.parent_sdfg = self.sdfg

sdfg.update_cfg_list([])
sdfg.update_cfg_list([])

# Make dictionary of autodetect connector types from set
if isinstance(inputs, (set, collections.abc.KeysView)):
@@ -1702,35 +1707,37 @@ def add_nested_sdfg(
schedule=schedule,
location=location,
debuginfo=debuginfo,
path=external_path,
)
self.add_node(s)

sdfg.parent_nsdfg_node = s

# Add "default" undefined symbols if None are given
symbols = sdfg.free_symbols
if symbol_mapping is None:
symbol_mapping = {s: s for s in symbols}
s.symbol_mapping = symbol_mapping

# Validate missing symbols
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols and parent:
# If symbols are missing, try to get them from the parent SDFG
parent_mapping = {s: s for s in missing_symbols if s in parent.symbols}
symbol_mapping.update(parent_mapping)
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))
if sdfg is not None:
sdfg.parent_nsdfg_node = s

# Add new global symbols to nested SDFG
from dace.codegen.tools.type_inference import infer_expr_type
for sym, symval in s.symbol_mapping.items():
if sym not in sdfg.symbols:
# TODO: Think of a better way to avoid calling
# symbols_defined_at in this moment
sdfg.add_symbol(sym, infer_expr_type(symval, self.sdfg.symbols) or dtypes.typeclass(int))
# Add "default" undefined symbols if None are given
symbols = sdfg.free_symbols
if symbol_mapping is None:
symbol_mapping = {s: s for s in symbols}
s.symbol_mapping = symbol_mapping

# Validate missing symbols
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols and parent:
# If symbols are missing, try to get them from the parent SDFG
parent_mapping = {s: s for s in missing_symbols if s in parent.symbols}
symbol_mapping.update(parent_mapping)
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))

# Add new global symbols to nested SDFG
from dace.codegen.tools.type_inference import infer_expr_type
for sym, symval in s.symbol_mapping.items():
if sym not in sdfg.symbols:
# TODO: Think of a better way to avoid calling
# symbols_defined_at in this moment
sdfg.add_symbol(sym, infer_expr_type(symval, self.sdfg.symbols) or dtypes.typeclass(int))

return s
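With the reorganized logic above, add_nested_sdfg accepts either an in-memory SDFG or a path to one stored on disk. A sketch of the external-path form; the file name, connector names, and the explicit name= argument are illustrative assumptions (a name must be supplied because there is no in-memory SDFG to take a label from):

import dace

outer = dace.SDFG('outer')
outer.add_array('A', [32], dace.float64)
state = outer.add_state()

# Reference a nested SDFG stored on disk instead of passing one in memory.
# 'nested.sdfg' must define matching non-transient arrays once it is loaded.
nsdfg = state.add_nested_sdfg(None, outer, inputs={'x_in'}, outputs={'x_out'},
                              name='nested', external_path='nested.sdfg')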

@@ -2818,23 +2825,27 @@ def add_state_after(self,
###################################################################
# Traversal methods

def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']:
def all_control_flow_regions(self, recursive=False, load_ext=False) -> Iterator['ControlFlowRegion']:
""" Iterate over this and all nested control flow regions. """
yield self
for block in self.nodes():
if isinstance(block, SDFGState) and recursive:
for node in block.nodes():
if isinstance(node, nd.NestedSDFG):
yield from node.sdfg.all_control_flow_regions(recursive=recursive)
if node.sdfg:
yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif load_ext:
node.load_external(block)
yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif isinstance(block, ControlFlowRegion):
yield from block.all_control_flow_regions(recursive=recursive)
yield from block.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif isinstance(block, ConditionalBlock):
for _, branch in block.branches:
yield from branch.all_control_flow_regions(recursive=recursive)
yield from branch.all_control_flow_regions(recursive=recursive, load_ext=load_ext)

def all_sdfgs_recursive(self) -> Iterator['SDFG']:
def all_sdfgs_recursive(self, load_ext=False) -> Iterator['SDFG']:
""" Iterate over this and all nested SDFGs. """
for cfg in self.all_control_flow_regions(recursive=True):
for cfg in self.all_control_flow_regions(recursive=True, load_ext=load_ext):
if isinstance(cfg, dace.SDFG):
yield cfg

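The load_ext flag is what the new loop in SDFG.compile() (dace/sdfg/sdfg.py above) relies on: traversing with load_ext=True materializes any nested SDFGs that are only referenced by path before code generation. A short illustrative use; the SDFG here is a stand-in for one that may contain such nodes:

import dace

sdfg = dace.SDFG('outer')   # illustrative; imagine it contains external nested SDFGs

# Force-load all externally stored nested SDFGs, mirroring SDFG.compile().
for nested in sdfg.all_sdfgs_recursive(load_ext=True):
    print(nested.name)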
2 changes: 1 addition & 1 deletion dace/sdfg/validation.py
@@ -754,7 +754,7 @@ def validate_state(state: 'dace.sdfg.SDFGState',
if isinstance(dst_node, nd.AccessNode) and e.data.data != dst_node.data else src_node)

if isinstance(subset_node, nd.AccessNode):
arr = sdfg.arrays[subset_node.data]
arr = sdfg.arrays[e.data.data]
# Dimensionality
if e.data.subset.dims() != len(arr.shape):
raise InvalidSDFGEdgeError(