diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 2cbd90707b..10b555f761 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4,7 +4,7 @@ import itertools from functools import reduce import re -from typing import Any, Dict, List, Tuple, Union, Callable +from typing import Any, Dict, List, Tuple, Union, Callable, Optional import warnings import dace @@ -1453,9 +1453,11 @@ def __init__( for stmt in _DISALLOWED_STMTS: setattr(self, 'visit_' + stmt, lambda n: _disallow_stmt(self, n)) - def parse_tasklet(self, tasklet_ast: TaskletType): + def parse_tasklet(self, tasklet_ast: TaskletType, + name: Optional[str] = None): """ Parses the AST of a tasklet and returns the tasklet node, as well as input and output memlets. :param tasklet_ast: The Tasklet's Python AST to parse. + :param name: Optional name to use as prefix for tasklet. :return: 3-tuple of (Tasklet node, input memlets, output memlets). @rtype: Tuple[Tasklet, Dict[str, Memlet], Dict[str, Memlet]] """ @@ -1469,7 +1471,11 @@ def parse_tasklet(self, tasklet_ast: TaskletType): self.filename) # Determine tasklet name (either declared as a function or use line #) - name = getattr(tasklet_ast, 'name', 'tasklet_%d' % tasklet_ast.lineno) + if name is not None: + name += '_' + str(tasklet_ast.lineno) + else: + name = getattr(tasklet_ast, 'name', + 'tasklet_%d' % tasklet_ast.lineno) t = self.state.add_tasklet(name, set(self.inputs.keys()), @@ -2624,7 +2630,7 @@ def visit_For(self, node: ast.For): me, mx = state.add_map(name='%s_%d' % (self.name, node.lineno), ndrange=params) # body = SDFG('MapBody') - body, inputs, outputs = self._parse_subprogram('MapBody', node) + body, inputs, outputs = self._parse_subprogram(self.name, node) tasklet = state.add_nested_sdfg(body, self.sdfg, inputs.keys(), outputs.keys()) self._add_dependencies(state, tasklet, me, mx, inputs, outputs, @@ -2700,7 +2706,7 @@ def _parse_index(self, node: ast.Index): return indices - def _parse_tasklet(self, state: SDFGState, node: TaskletType): + def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): ttrans = TaskletTransformer(self.defined, self.sdfg, state, @@ -2710,7 +2716,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType): scope_vars=self.scope_vars, variables=self.variables, accesses=self.accesses) - node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node) + node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name) # Convert memlets to their actual data nodes for i in inputs.values(): @@ -3644,8 +3650,16 @@ def visit_With(self, node, is_async=False): if funcname == 'dace.tasklet': # Parse as tasklet state = self._add_state('with_%d' % node.lineno) - tasklet, inputs, outputs, sdfg_inp, sdfg_out = self._parse_tasklet( - state, node) + + # Parse tasklet name + namelist = self.name.split('_') + if len(namelist) > 2: # Remove trailing line and column number + name = '_'.join(namelist[:-2]) + else: + name = self.name + + tasklet, inputs, outputs, sdfg_inp, sdfg_out = \ + self._parse_tasklet(state, node, name) # Add memlets self._add_dependencies(state, tasklet, None, None, inputs, diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 8ecc146448..e3ce2824b1 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -168,11 +168,21 @@ namespace dace return std::sin(a); } template + DACE_CONSTEXPR DACE_HDFI T sinh(const T& a) + { + return std::sinh(a); + } + template DACE_CONSTEXPR DACE_HDFI T cos(const T& a) { return std::cos(a); } template + DACE_CONSTEXPR DACE_HDFI T cosh(const T& a) + { + return std::cosh(a); + } + template DACE_CONSTEXPR DACE_HDFI T tan(const T& a) { return std::tan(a); diff --git a/dace/sdfg.py b/dace/sdfg.py index e21166bff1..6b30f4437c 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -272,6 +272,10 @@ def to_json(self): except RuntimeError: tmp['scalar_parameters'] = [] + # Location in the SDFG list + self.reset_sdfg_list() + tmp['sdfg_list_id'] = int(self.sdfg_list.index(self)) + tmp['attributes']['name'] = self.name return tmp @@ -512,6 +516,13 @@ def remove_data(self, name, validate=True): del self._arrays[name] + def reset_sdfg_list(self): + if self.parent_sdfg is not None: + self._sdfg_list = self.parent_sdfg.reset_sdfg_list() + else: + self._sdfg_list = list(self.all_sdfgs_recursive()) + return self._sdfg_list + def update_sdfg_list(self, sdfg_list): # TODO: Refactor sub_sdfg_list = self._sdfg_list diff --git a/diode/diode_server.py b/diode/diode_server.py index 60597ff277..0691250559 100755 --- a/diode/diode_server.py +++ b/diode/diode_server.py @@ -7,6 +7,7 @@ from dace.codegen import codegen from diode.DaceState import DaceState from dace.transformation.optimizer import SDFGOptimizer +from dace.transformation.pattern_matching import Transformation from dace.graph.nodes import LibraryNode import inspect from flask import Flask, Response, request, redirect, url_for, abort, jsonify, send_from_directory, send_file @@ -609,20 +610,27 @@ def applySDFGProperties(sdfg, properties, step=None): return sdfg -def applyOptPath(sdfg, optpath, useGlobalSuffix=True, sdfg_props=[]): +def applyOptPath(sdfg, optpath, useGlobalSuffix=True, sdfg_props=None): # Iterate over the path, applying the transformations global_counter = {} - if sdfg_props is None: sdfg_props = [] + sdfg_props = sdfg_props or [] step = 0 for x in optpath: optimizer = SDFGOptimizer(sdfg, inplace=True) - matching = optimizer.get_pattern_matches() + + name = x['name'] + classname = name[:name.index('$')] if name.find('$') >= 0 else name + + transformation = next(t for t in Transformation.extensions().keys() if + t.__name__ == classname) + matching = optimizer.get_pattern_matches(patterns=[transformation]) # Apply properties (will automatically apply by step-matching) sdfg = applySDFGProperties(sdfg, sdfg_props, step) for pattern in matching: name = type(pattern).__name__ + tsdfg = sdfg.sdfg_list[pattern.sdfg_id] if useGlobalSuffix: if name in global_counter: @@ -642,7 +650,7 @@ def applyOptPath(sdfg, optpath, useGlobalSuffix=True, sdfg_props=[]): dace.serialize.set_properties_from_json(pattern, x['params']['props'], context=sdfg) - pattern.apply_pattern(sdfg) + pattern.apply_pattern(tsdfg) if not useGlobalSuffix: break @@ -877,10 +885,11 @@ def get_transformations(sdfgs): nodeids = [] properties = [] if p is not None: + sdfg_id = p.sdfg_id sid = p.state_id nodes = list(p.subgraph.values()) for n in nodes: - nodeids.append([sid, n]) + nodeids.append([sdfg_id, sid, n]) properties = dace.serialize.all_properties_to_json(p) optimizations.append({ diff --git a/diode/webclient/diode.js b/diode/webclient/diode.js index 082ed1999c..b2872c43ef 100644 --- a/diode/webclient/diode.js +++ b/diode/webclient/diode.js @@ -472,17 +472,31 @@ class DIODE_Context_SDFG extends DIODE_Context { }); this.highlighted_elements = []; - let graph = this.renderer_pane.graph; - // The input contains a list of multiple elements for (let x of msg.elements) { - let sid = x[0], nid = x[1]; + let sdfg_id = x[0], sid = x[1], nid = x[2]; let elem = null; + let graph = null; + if (sdfg_id >= 0) + graph = this.renderer_pane.sdfg_list[sdfg_id]; + else + graph = this.renderer_pane.graph; + + // If graph is hidden, skip + if (graph === undefined) + continue; + if (sid == -1) elem = graph.node(nid); - else - elem = graph.node(sid).data.graph.node(nid); - this.highlighted_elements.push(elem); + else { + let state = graph.node(sid); + // If state is hidden, skip + if (state === undefined) + continue; + elem = state.data.graph.node(nid); + } + if (elem !== undefined) + this.highlighted_elements.push(elem); } this.highlighted_elements.forEach(e => { if (e) e.stroke_color = "#D35400"; @@ -591,7 +605,7 @@ class DIODE_Context_SDFG extends DIODE_Context { o = JSON.parse(o); } while (typeof o.sdfg == 'string') { - o.sdfg = JSON.parse(o.sdfg); + o.sdfg = parse_sdfg(o.sdfg); } return o; } @@ -704,10 +718,10 @@ class DIODE_Context_SDFG extends DIODE_Context { into the state. */ let nref = node.element; - let sdfg = node.sdfg; nref.attributes[name] = value; + let sdfg = this.renderer_pane.sdfg; let old = this.getState(); if (old.type == "SDFG") old = sdfg; @@ -1394,31 +1408,34 @@ class DIODE_Context_AvailableTransformations extends DIODE_Context { } }]; - let tmp = () => { + let tmp = (x) => { // Compile after the transformation has been saved this.diode.gatherProjectElementsAndCompile(this, { optpath: named }, { sdfg_over_code: true + }, () => { + this.project().saveSnapshot(x['sdfg_object'], named); + + this.project().request(['update-tfh'], x => { + this.operation_running = false; + }, { + on_timeout: () => { + this.operation_running = false; + } + }); + }, () => { + // On failure + this.operation_running = false; }); }; this.project().request(['sdfg_object'], x => { console.log("Got snapshot", x); - if (typeof (x.sdfg_object) == 'string') + if(typeof(x.sdfg_object) == 'string') x.sdfg_object = JSON.parse(x.sdfg_object); - this.project().saveSnapshot(x['sdfg_object'], named); - - this.project().request(['update-tfh'], x => { - this.operation_running = false; - }, { - on_timeout: () => { - this.operation_running = false; - } - }); - - setTimeout(tmp, 10); + setTimeout(tmp, 10, x); }, {}); } @@ -2811,6 +2828,8 @@ class DIODE_Project { this._listeners = {}; this._closed_windows = []; + + this._blob = null; } clearTransformationHistory() { @@ -3012,41 +3031,47 @@ class DIODE_Project { return JSON.parse(tmp); } - save() { - /* - Saves all elements of this project to its own slot in the local storage - (such that it can be opened again even if the window was closed). - - */ - - let snapshots = this.getTransformationSnapshots(); - if (typeof (snapshots) == 'string') - snapshots = JSON.parse(snapshots); - let y = { - project_id: this._project_id, - data: this._diode.goldenlayout.toConfig(), - snapshots: snapshots, - last_saved: new Date().toLocaleDateString(), - description: "" - }; - let save_val = JSON.stringify(y); + createblob(data) { + var blob = new Blob([data], { + type: 'text/plain' + }); - // The sdfg is not sufficiently unique. - let save_name = prompt("Enter project name"); + // If we are replacing a previously generated file we need to + // manually revoke the object URL to avoid memory leaks. + if (this._blob !== null) { + window.URL.revokeObjectURL(this._blob); + } - window.localStorage.setItem("project_" + save_name, save_val); - - let sp = window.localStorage.getItem("saved_projects"); - if (sp == null) { - sp = []; - } else { - sp = JSON.parse(sp); - } - - sp = [save_name, ...sp]; - window.localStorage.setItem("saved_projects", JSON.stringify(sp)); + this._blob = window.URL.createObjectURL(blob); + return this._blob; + } + save() { + // Save current open file as SDFG + this.request(['sdfg_object'], x => { + let sdfg = x.sdfg_object; + let filename = null; + if (typeof (sdfg) != 'string') { + filename = Object.keys(x.sdfg_object)[0]; + sdfg = stringify_sdfg(Object.values(x.sdfg_object)[0]); + } else { + let sdfg_obj = parse_sdfg(sdfg); + filename = sdfg_obj.attributes.name; + } + filename += '.sdfg'; + + var link = document.createElement('a'); + link.setAttribute('download', filename); + link.href = this.createblob(sdfg); + document.body.appendChild(link); + // wait for the link to be added to the document + window.requestAnimationFrame(function () { + var event = new MouseEvent('click'); + link.dispatchEvent(event); + document.body.removeChild(link); + }); + }); } request(list, callback, options = {}) { @@ -5030,7 +5055,7 @@ class DIODE { // Expand library node REST_request("/dace/api/v1.0/expand/", { sdfg: node.sdfg, - nodeid: [0, node.element.parent_id, node.element.id] + nodeid: [node.sdfg.sdfg_list_id, node.element.parent_id, node.element.id] }, (xhr) => { if (xhr.readyState === 4 && xhr.status === 200) { let resp = parse_sdfg(xhr.response); @@ -5179,6 +5204,9 @@ class DIODE { } else if (x.metatype == "font") { console.warn("Ignoring property type ", x.metatype); return elem; + } else if(x.metatype == "SDFGReferenceProperty") { + // Nothing to display + return elem; } else if (x.metatype == "SubsetProperty") { if (x.value == null) { elem = FormBuilder.createTextInput("prop_" + x.name, (elem) => { @@ -5734,7 +5762,8 @@ class DIODE { } } - gatherProjectElementsAndCompile(calling_context, direct_passing = {}, options = {}) { + gatherProjectElementsAndCompile(calling_context, direct_passing = {}, options = {}, + on_success = undefined, on_failure = undefined) { /* This method collects all available elements that can be used for compilation. @@ -5786,7 +5815,7 @@ class DIODE { let cval = values['input_code']; // Assuming SDFG files start with { - if (cval[0] == '{') { + if (!cis && cval[0] == '{') { let sd = parse_sdfg(cval); values['sdfg_object'] = {}; values['sdfg_object'][sd.attributes.name] = cval; @@ -5811,7 +5840,7 @@ class DIODE { runopts['repetitions'] = 5; // TODO(later): Allow users to configure number runopts['code_is_sdfg'] = cis; runopts['runnercode'] = values['input_code']; - this.compile_and_run(calling_context, options.term_id, cval, values['optpath'], values['sdfg_props'], runopts); + this.compile_and_run(calling_context, options.term_id, cval, values['optpath'], values['sdfg_props'], runopts, on_success, on_failure); } else { let cb = (resp) => { this.replaceOrCreate(['extend-optgraph'], 'AvailableTransformationsComponent', resp, (_) => { @@ -5826,7 +5855,7 @@ class DIODE { { optpath_cb: cb, code_is_sdfg: cis, - }); + }, on_success, on_failure); } } @@ -5834,7 +5863,8 @@ class DIODE { calling_context.project().request(reqlist, on_collected, {timeout: 500, on_timeout: on_collected}); } - compile(calling_context, code, optpath = undefined, sdfg_node_properties = undefined, options = {}) { + compile(calling_context, code, optpath = undefined, sdfg_node_properties = undefined, options = {}, + on_success = undefined, on_failure = undefined) { /* options: .code_is_sdfg: If true, the code parameter is treated as a serialized SDFG @@ -5864,6 +5894,8 @@ class DIODE { if (peek['error'] != undefined) { // There was at least one error - handle all of them this.handleErrors(calling_context, peek); + if (on_failure !== undefined) + on_failure(); } else { // Data is no longer stale this.removeStaleDataButton(); @@ -5875,6 +5907,8 @@ class DIODE { } else { options.optpath_cb(o['compounds']); } + if (on_success !== undefined) + on_success(); } } }); @@ -5997,7 +6031,9 @@ class DIODE { this.addContentItem(newconf); } - compile_and_run(calling_context, terminal_identifier, code, optpath = undefined, sdfg_node_properties = undefined, options = {}) { + compile_and_run(calling_context, terminal_identifier, code, optpath = undefined, + sdfg_node_properties = undefined, options={}, on_success = undefined, + on_failure = undefined) { /* .runnercode: [opt] Code provided with SDFG to invoke the SDFG program. */ @@ -6026,6 +6062,8 @@ class DIODE { } else { alert("Error! Check console"); console.error("Unknown instrumentation mode", remaining_settings['Instrumentation']); + if(on_failure !== undefined) on_failure(); + return; } //post_params['perfmodes'] = ["default", "vectorize", "memop", "cacheop"]; let not = remaining_settings['Number of threads']; @@ -6044,7 +6082,10 @@ class DIODE { if (tmp['error']) { // Normal, users should poll on a different channel now. this.display_current_execution_status(calling_context, terminal_identifier, client_id); - } + if (on_failure !== undefined) + on_failure(); + } else if (on_success !== undefined) + on_success(); } }); }); diff --git a/diode/webclient/main.js b/diode/webclient/main.js index 2699481416..14318e2299 100644 --- a/diode/webclient/main.js +++ b/diode/webclient/main.js @@ -864,6 +864,7 @@ function start_DIODE() { diode.open_diode_settings(); }); diode.addKeyShortcut('r', () => { diode.gatherProjectElementsAndCompile(diode, {}, { sdfg_over_code: true }); }); + diode.addKeyShortcut('s', () => { diode.project().save(); }, false, true); diode.setupEvents(); diff --git a/diode/webclient/renderer.js b/diode/webclient/renderer.js index 8074603805..86ee9823e6 100644 --- a/diode/webclient/renderer.js +++ b/diode/webclient/renderer.js @@ -385,7 +385,7 @@ function calculateNodeSize(sdfg_state, node, ctx) { } // Layout SDFG elements (states, nodes, scopes, nested SDFGs) -function relayout_sdfg(ctx, sdfg) { +function relayout_sdfg(ctx, sdfg, sdfg_list) { let STATE_MARGIN = 4*LINEHEIGHT; // Layout the SDFG as a dagre graph @@ -404,7 +404,7 @@ function relayout_sdfg(ctx, sdfg) { stateinfo.height = LINEHEIGHT; } else { - state_g = relayout_state(ctx, state, sdfg); + state_g = relayout_state(ctx, state, sdfg, sdfg_list); stateinfo = calculateBoundingBox(state_g); } stateinfo.width += 2*STATE_MARGIN; @@ -464,10 +464,13 @@ function relayout_sdfg(ctx, sdfg) { g.width = bb.width; g.height = bb.height; + // Add SDFG to global store + sdfg_list[sdfg.sdfg_list_id] = g; + return g; } -function relayout_state(ctx, sdfg_state, sdfg) { +function relayout_state(ctx, sdfg_state, sdfg, sdfg_list) { // layout the state as a dagre graph let g = new dagre.graphlib.Graph({multigraph: true}); @@ -504,7 +507,7 @@ function relayout_state(ctx, sdfg_state, sdfg) { // Recursively lay out nested SDFGs if (node.type === "NestedSDFG") { - nested_g = relayout_sdfg(ctx, node.attributes.sdfg); + nested_g = relayout_sdfg(ctx, node.attributes.sdfg, sdfg_list); let sdfginfo = calculateBoundingBox(nested_g); node.attributes.layout.width = sdfginfo.width + 2*LINEHEIGHT; node.attributes.layout.height = sdfginfo.height + 2*LINEHEIGHT; @@ -653,6 +656,7 @@ class SDFGRenderer { constructor(sdfg, container, on_mouse_event = null) { // DIODE/SDFV-related fields this.sdfg = sdfg; + this.sdfg_list = {}; // Rendering-related fields this.container = container; @@ -822,7 +826,8 @@ class SDFGRenderer { // Re-layout graph and nested graphs relayout() { - this.graph = relayout_sdfg(this.ctx, this.sdfg); + this.sdfg_list = {}; + this.graph = relayout_sdfg(this.ctx, this.sdfg, this.sdfg_list); this.onresize(); return this.graph;