Skip to content

Commit

Permalink
SDFG API additions for version 1.0 (spcl#1740)
Browse files Browse the repository at this point in the history
This PR adds additional API calls and fields as requested by DaCe users.
This includes:

* `SDFG.auto_optimize`
* `SDFG.regenerate_code`
* `SDFG.as_schedule_tree`
  • Loading branch information
tbennun authored Nov 8, 2024
1 parent d61122d commit 08ec5ea
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
4 changes: 2 additions & 2 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def _load_sdfg(self, path: str, *args, **kwargs):

if sdfg is not None:
# Set regenerate and recompile flags
sdfg._regenerate_code = self.regenerate_code
sdfg.regenerate_code = self.regenerate_code
sdfg._recompile = self.recompile

return sdfg, self._cache.make_key(argtypes, given_args, self.closure_array_keys, self.closure_constant_keys,
Expand Down Expand Up @@ -928,7 +928,7 @@ def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any],
# TODO: Add to parsed SDFG cache

# Set regenerate and recompile flags
sdfg._regenerate_code = self.regenerate_code
sdfg.regenerate_code = self.regenerate_code
sdfg._recompile = self.recompile

return sdfg, cached
61 changes: 60 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from dace.codegen.instrumentation.report import InstrumentationReport
from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeScope


class NestedDict(dict):
Expand Down Expand Up @@ -802,6 +803,14 @@ def start_state(self):
def start_state(self, state_id):
self.start_block = state_id

@property
def regenerate_code(self):
return self._regenerate_code

@regenerate_code.setter
def regenerate_code(self, value):
self._regenerate_code = value

def set_global_code(self, cpp_code: str, location: str = 'frame'):
"""
Sets C++ code that will be generated in a global scope on
Expand Down Expand Up @@ -1070,6 +1079,24 @@ def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args,

##########################################

def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope':
"""
Creates a schedule tree from this SDFG and all nested SDFGs. The schedule tree is a tree of nodes that represent
the execution order of the SDFG.
Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node,
etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes.
It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example,
erasing an empty if branch, or merging two consecutive for-loops.
:param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might
not be usable after the conversion if ``in_place`` is True!
:return: A schedule tree representing the given SDFG.
"""
# Avoid import loop
from dace.sdfg.analysis.schedule_tree import sdfg_to_tree as s2t
return s2t.as_schedule_tree(self, in_place=in_place)

@property
def build_folder(self) -> str:
""" Returns a relative path to the build cache folder for this SDFG. """
Expand Down Expand Up @@ -2293,7 +2320,7 @@ def compile(self, output_file=None, validate=True,
############################
# DaCe Compilation Process #

if self._regenerate_code or not os.path.isdir(build_folder):
if self.regenerate_code or not os.path.isdir(build_folder):
# Clone SDFG as the other modules may modify its contents
sdfg = copy.deepcopy(self)
# Fix the build folder name on the copied SDFG to avoid it changing
Expand Down Expand Up @@ -2463,6 +2490,38 @@ def simplify(self, validate=True, validate_all=False, verbose=False):
from dace.transformation.passes.simplify import SimplifyPass
return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose).apply_pass(self, {})

def auto_optimize(self,
device: dtypes.DeviceType,
validate: bool = True,
validate_all: bool = False,
symbols: Dict[str, int] = None,
use_gpu_storage: bool = False):
"""
Runs a basic sequence of transformations to optimize a given SDFG to decent
performance. In particular, performs the following:
* Simplify
* Auto-parallelization (loop-to-map)
* Greedy application of SubgraphFusion
* Tiled write-conflict resolution (MapTiling -> AccumulateTransient)
* Tiled stream accumulation (MapTiling -> AccumulateTransient)
* Collapse all maps to parallelize across all dimensions
* Set all library nodes to expand to ``fast`` expansion, which calls
the fastest library on the target device
:param device: the device to optimize for.
:param validate: If True, validates the SDFG after all transformations
have been applied.
:param validate_all: If True, validates the SDFG after every step.
:param symbols: Optional dict that maps symbols (str/symbolic) to int/float
:param use_gpu_storage: If True, changes the storage of non-transient data to GPU global memory.
:note: Operates in-place on the given SDFG.
:note: This function is still experimental and may harm correctness in
certain cases. Please report an issue if it does.
"""
from dace.transformation.auto.auto_optimize import auto_optimize
auto_optimize(device, validate, validate_all, symbols, use_gpu_storage)

def _initialize_transformations_from_type(
self,
xforms: Union[Type, List[Type], 'dace.transformation.PatternTransformation'],
Expand Down

0 comments on commit 08ec5ea

Please sign in to comment.