Skip to content

Commit

Permalink
Merge pull request spcl#96 from spcl/saved-projects
Browse files Browse the repository at this point in the history
DIODE: Remove saved_projects and miscellaneous fixes
  • Loading branch information
tbennun authored Mar 14, 2020
2 parents 9c27fc8 + 24b3a18 commit 3c33fa9
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 78 deletions.
30 changes: 22 additions & 8 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from functools import reduce
import re
from typing import Any, Dict, List, Tuple, Union, Callable
from typing import Any, Dict, List, Tuple, Union, Callable, Optional
import warnings

import dace
Expand Down Expand Up @@ -1453,9 +1453,11 @@ def __init__(
for stmt in _DISALLOWED_STMTS:
setattr(self, 'visit_' + stmt, lambda n: _disallow_stmt(self, n))

def parse_tasklet(self, tasklet_ast: TaskletType):
def parse_tasklet(self, tasklet_ast: TaskletType,
name: Optional[str] = None):
""" Parses the AST of a tasklet and returns the tasklet node, as well as input and output memlets.
:param tasklet_ast: The Tasklet's Python AST to parse.
:param name: Optional name to use as prefix for tasklet.
:return: 3-tuple of (Tasklet node, input memlets, output memlets).
@rtype: Tuple[Tasklet, Dict[str, Memlet], Dict[str, Memlet]]
"""
Expand All @@ -1469,7 +1471,11 @@ def parse_tasklet(self, tasklet_ast: TaskletType):
self.filename)

# Determine tasklet name (either declared as a function or use line #)
name = getattr(tasklet_ast, 'name', 'tasklet_%d' % tasklet_ast.lineno)
if name is not None:
name += '_' + str(tasklet_ast.lineno)
else:
name = getattr(tasklet_ast, 'name',
'tasklet_%d' % tasklet_ast.lineno)

t = self.state.add_tasklet(name,
set(self.inputs.keys()),
Expand Down Expand Up @@ -2624,7 +2630,7 @@ def visit_For(self, node: ast.For):
me, mx = state.add_map(name='%s_%d' % (self.name, node.lineno),
ndrange=params)
# body = SDFG('MapBody')
body, inputs, outputs = self._parse_subprogram('MapBody', node)
body, inputs, outputs = self._parse_subprogram(self.name, node)
tasklet = state.add_nested_sdfg(body, self.sdfg, inputs.keys(),
outputs.keys())
self._add_dependencies(state, tasklet, me, mx, inputs, outputs,
Expand Down Expand Up @@ -2700,7 +2706,7 @@ def _parse_index(self, node: ast.Index):

return indices

def _parse_tasklet(self, state: SDFGState, node: TaskletType):
def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
ttrans = TaskletTransformer(self.defined,
self.sdfg,
state,
Expand All @@ -2710,7 +2716,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType):
scope_vars=self.scope_vars,
variables=self.variables,
accesses=self.accesses)
node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node)
node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name)

# Convert memlets to their actual data nodes
for i in inputs.values():
Expand Down Expand Up @@ -3644,8 +3650,16 @@ def visit_With(self, node, is_async=False):
if funcname == 'dace.tasklet':
# Parse as tasklet
state = self._add_state('with_%d' % node.lineno)
tasklet, inputs, outputs, sdfg_inp, sdfg_out = self._parse_tasklet(
state, node)

# Parse tasklet name
namelist = self.name.split('_')
if len(namelist) > 2: # Remove trailing line and column number
name = '_'.join(namelist[:-2])
else:
name = self.name

tasklet, inputs, outputs, sdfg_inp, sdfg_out = \
self._parse_tasklet(state, node, name)

# Add memlets
self._add_dependencies(state, tasklet, None, None, inputs,
Expand Down
10 changes: 10 additions & 0 deletions dace/runtime/include/dace/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,21 @@ namespace dace
return std::sin(a);
}
template<typename T>
DACE_CONSTEXPR DACE_HDFI T sinh(const T& x)
{
    // Hyperbolic sine: thin forwarding wrapper around the standard library.
    return std::sinh(x);
}
template<typename T>
DACE_CONSTEXPR DACE_HDFI T cos(const T& x)
{
    // Cosine: thin forwarding wrapper around the standard library.
    return std::cos(x);
}
template<typename T>
DACE_CONSTEXPR DACE_HDFI T cosh(const T& x)
{
    // Hyperbolic cosine: thin forwarding wrapper around the standard library.
    return std::cosh(x);
}
template<typename T>
DACE_CONSTEXPR DACE_HDFI T tan(const T& a)
{
return std::tan(a);
Expand Down
11 changes: 11 additions & 0 deletions dace/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ def to_json(self):
except RuntimeError:
tmp['scalar_parameters'] = []

# Location in the SDFG list
self.reset_sdfg_list()
tmp['sdfg_list_id'] = int(self.sdfg_list.index(self))

tmp['attributes']['name'] = self.name

return tmp
Expand Down Expand Up @@ -512,6 +516,13 @@ def remove_data(self, name, validate=True):

del self._arrays[name]

def reset_sdfg_list(self):
    """Recompute the cached flat list of SDFGs and return it.

    If this SDFG has a parent, delegate to the root so that every SDFG
    in the hierarchy shares one freshly rebuilt list; otherwise rebuild
    the list here from all recursively reachable SDFGs.

    :return: The recomputed list of SDFGs (self included).
    """
    if self.parent_sdfg is None:
        # This is the root: enumerate every nested SDFG from here.
        self._sdfg_list = list(self.all_sdfgs_recursive())
    else:
        # Not the root: let the parent (and ultimately the root) rebuild.
        self._sdfg_list = self.parent_sdfg.reset_sdfg_list()
    return self._sdfg_list

def update_sdfg_list(self, sdfg_list):
# TODO: Refactor
sub_sdfg_list = self._sdfg_list
Expand Down
19 changes: 14 additions & 5 deletions diode/diode_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dace.codegen import codegen
from diode.DaceState import DaceState
from dace.transformation.optimizer import SDFGOptimizer
from dace.transformation.pattern_matching import Transformation
from dace.graph.nodes import LibraryNode
import inspect
from flask import Flask, Response, request, redirect, url_for, abort, jsonify, send_from_directory, send_file
Expand Down Expand Up @@ -609,20 +610,27 @@ def applySDFGProperties(sdfg, properties, step=None):
return sdfg


def applyOptPath(sdfg, optpath, useGlobalSuffix=True, sdfg_props=[]):
def applyOptPath(sdfg, optpath, useGlobalSuffix=True, sdfg_props=None):
# Iterate over the path, applying the transformations
global_counter = {}
if sdfg_props is None: sdfg_props = []
sdfg_props = sdfg_props or []
step = 0
for x in optpath:
optimizer = SDFGOptimizer(sdfg, inplace=True)
matching = optimizer.get_pattern_matches()

name = x['name']
classname = name[:name.index('$')] if name.find('$') >= 0 else name

transformation = next(t for t in Transformation.extensions().keys() if
t.__name__ == classname)
matching = optimizer.get_pattern_matches(patterns=[transformation])

# Apply properties (will automatically apply by step-matching)
sdfg = applySDFGProperties(sdfg, sdfg_props, step)

for pattern in matching:
name = type(pattern).__name__
tsdfg = sdfg.sdfg_list[pattern.sdfg_id]

if useGlobalSuffix:
if name in global_counter:
Expand All @@ -642,7 +650,7 @@ def applyOptPath(sdfg, optpath, useGlobalSuffix=True, sdfg_props=[]):
dace.serialize.set_properties_from_json(pattern,
x['params']['props'],
context=sdfg)
pattern.apply_pattern(sdfg)
pattern.apply_pattern(tsdfg)

if not useGlobalSuffix:
break
Expand Down Expand Up @@ -877,10 +885,11 @@ def get_transformations(sdfgs):
nodeids = []
properties = []
if p is not None:
sdfg_id = p.sdfg_id
sid = p.state_id
nodes = list(p.subgraph.values())
for n in nodes:
nodeids.append([sid, n])
nodeids.append([sdfg_id, sid, n])

properties = dace.serialize.all_properties_to_json(p)
optimizations.append({
Expand Down
Loading

0 comments on commit 3c33fa9

Please sign in to comment.