Skip to content

Commit

Permalink
Merge branch 'master' into simplify-ext
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Mar 10, 2020
2 parents 7998e28 + a927bd6 commit 9d3b855
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 48 deletions.
52 changes: 27 additions & 25 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import warnings

from dace import data, registry
from dace import data, registry, memlet as mm
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.targets.cpp import *
from dace.codegen.targets.target import TargetCodeGenerator, make_absolute, \
Expand Down Expand Up @@ -744,12 +744,18 @@ def process_out_memlets(self,
"Cannot copy memlet without a local connector: {} to {}"
.format(str(edge.src), str(edge.dst)))

could_be_scalar = True
if isinstance(node, nodes.NestedSDFG):
could_be_scalar = not isinstance(node.sdfg.arrays[uconn],
data.Array)

try:
positive_accesses = bool(memlet.num_accesses >= 0)
except TypeError:
positive_accesses = False

if memlet.subset.data_dims() == 0 and positive_accesses:
if (memlet.subset.data_dims() == 0 and positive_accesses
and could_be_scalar):
out_local_name = " __" + uconn
in_local_name = uconn
if not locals_defined:
Expand Down Expand Up @@ -904,11 +910,13 @@ def memlet_view_ctor(self, sdfg, memlet, is_output):
)

def memlet_definition(self,
sdfg,
memlet,
output,
local_name,
sdfg: SDFG,
memlet: mm.Memlet,
output: bool,
local_name: str,
conntype: data.Data = None,
allow_shadowing=False):
could_be_scalar = not conntype or not isinstance(conntype, data.Array)
result = ("auto __%s = " % local_name +
self.memlet_ctor(sdfg, memlet, output) + ";\n")

Expand Down Expand Up @@ -959,7 +967,8 @@ def memlet_definition(self,
DefinedType.Scalar,
allow_shadowing=allow_shadowing)
elif var_type == DefinedType.Pointer:
if memlet.num_accesses == 1 and memlet.subset.num_elements() == 1:
if (memlet.num_accesses == 1 and memlet.subset.num_elements() == 1
and could_be_scalar):
if output:
result += "{} {};".format(memlet_type, local_name)
else:
Expand All @@ -970,7 +979,7 @@ def memlet_definition(self,
DefinedType.Scalar,
allow_shadowing=allow_shadowing)
else:
if memlet.subset.data_dims() == 0:
if memlet.subset.data_dims() == 0 and could_be_scalar:
# Forward ArrayView
result += "auto &{} = __{}.ref<{}>();".format(
local_name, local_name, memlet.veclen)
Expand Down Expand Up @@ -1290,6 +1299,7 @@ def _generate_NestedSDFG(
in_memlet,
False,
vconn,
conntype=node.sdfg.arrays[vconn],
allow_shadowing=True), sdfg, state_id,
node)
for _, uconn, _, _, out_memlet in state_dfg.out_edges(node):
Expand All @@ -1298,11 +1308,13 @@ def _generate_NestedSDFG(
out_code = emit_memlet_reference(self._dispatcher, sdfg,
out_memlet, uconn)
else:
out_code = self.memlet_definition(sdfg,
out_memlet,
True,
uconn,
allow_shadowing=True)
out_code = self.memlet_definition(
sdfg,
out_memlet,
True,
uconn,
conntype=node.sdfg.arrays[uconn],
allow_shadowing=True)

callsite_stream.write(out_code, sdfg, state_id, node)

Expand Down Expand Up @@ -1978,18 +1990,8 @@ def _generate_Reduce(self, sdfg, dfg, state_id, node, function_stream,

# Store back tmpout into the true output
if i == end_braces - 1 and use_tmpout:
# TODO: This is a targeted fix that has to be generalized when
# refactoring code generation. The issue is related to an
# inconsistency on whether an output connector generates a tmp
# scalar variable to be used with __write or a pointer to the
# output array.
scalar_output = True
for r in output_subset:
if r != 0 and r != (0, 0, 1):
scalar_output = False
break
arr = sdfg.arrays[output_memlet.data]
if scalar_output and sdfg.parent_sdfg and not arr.transient:
if (self._dispatcher.defined_vars.get(
output_memlet.data) == DefinedType.Scalar):
out_var = output_memlet.data
else:
out_var = cpp_array_expr(sdfg, output_memlet)
Expand Down
44 changes: 21 additions & 23 deletions dace/codegen/targets/xilinx.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import collections
import itertools
import os
import re

Expand Down Expand Up @@ -286,10 +288,12 @@ def generate_kernel_boilerplate_pre(self, sdfg, state_id, kernel_name,
if kernel_arg:
kernel_args.append(kernel_arg)

scalar_parameters = collections.OrderedDict(scalar_parameters)
symbol_parameters.update(scalar_parameters)
kernel_args += ([
arg.signature(with_types=True, name=argname)
for argname, arg in scalar_parameters
] + symbol_params)
for argname, arg in symbol_parameters.items()
])

# Write kernel signature
kernel_stream.write(
Expand Down Expand Up @@ -326,11 +330,15 @@ def generate_host_function_body(self, sdfg, state, kernel_name, parameters,
symbol_parameters, kernel_stream):

# Just collect all variable names for calling the kernel function
kernel_args = [
p.signature(False, name=name) for is_output, name, p in parameters
]

kernel_args += symbol_parameters.keys()
added = set()
kernel_args = []
for _, name, p in itertools.chain(
parameters,
[(False, k, v) for k, v in symbol_parameters.items()]):
if not isinstance(p, dace.data.Array) and name in added:
continue
added.add(name)
kernel_args.append(p.signature(False, name=name))

kernel_function_name = kernel_name
kernel_file_name = "{}.xclbin".format(kernel_name)
Expand All @@ -353,18 +361,13 @@ def generate_module(self, sdfg, state, name, subgraph, parameters,
state_id = sdfg.node_id(state)
dfg = sdfg.nodes()[state_id]

# Treat scalars and symbols the same, assuming there are no scalar
# outputs
symbol_sigs = [
v.signature(with_types=True, name=k)
for k, v in symbol_parameters.items()
]
symbol_names = symbol_parameters.keys()
kernel_args_call = []
kernel_args_module = []
added = set()

for is_output, pname, p in parameters:
for is_output, pname, p in itertools.chain(
parameters,
[(False, k, v) for k, v in symbol_parameters.items()]):
if isinstance(p, dace.data.Array):
arr_name = "{}_{}".format(pname, "out" if is_output else "in")
kernel_args_call.append(arr_name)
Expand Down Expand Up @@ -393,8 +396,6 @@ def generate_module(self, sdfg, state, name, subgraph, parameters,
p.signature(with_types=False, name=pname))
kernel_args_module.append(
p.signature(with_types=True, name=pname))
kernel_args_call += symbol_names
kernel_args_module += symbol_sigs
module_function_name = "module_" + name
# Unrolling processing elements: if there first scope of the subgraph
# is an unrolled map, generate a processing element for each iteration
Expand Down Expand Up @@ -622,7 +623,9 @@ def generate_host_header(self, sdfg, kernel_function_name, parameters,
kernel_args = []

seen = set()
for is_output, name, arg in parameters:
for is_output, name, arg in itertools.chain(
parameters,
[(False, k, v) for k, v in symbol_parameters.items()]):
if isinstance(arg, dace.data.Array):
kernel_args.append(
arg.signature(with_types=True,
Expand All @@ -634,11 +637,6 @@ def generate_host_header(self, sdfg, kernel_function_name, parameters,
seen.add(name)
kernel_args.append(arg.signature(with_types=True, name=name))

kernel_args += [
v.signature(with_types=True, name=k)
for k, v in symbol_parameters.items()
]

host_code_stream.write(
"""\
// Signature of kernel function (with raw pointers) for argument matching
Expand Down
38 changes: 38 additions & 0 deletions tests/nested_sdfg_scalar_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import dace
import numpy as np


def _construct_sdfg():
sdfg = dace.SDFG('nsstest')
sdfg.add_array('A', [2], dace.float64)
state = sdfg.add_state()

nsdfg = dace.SDFG('nested')
nsdfg.add_array('a', [1], dace.float64)
nsdfg.add_array('b', [1], dace.float64)
nstate = nsdfg.add_state()
nstate.add_mapped_tasklet('m',
dict(i='0'),
dict(inp=dace.Memlet.simple('a', 'i')),
'out = inp * 5.0',
dict(out=dace.Memlet.simple('b', 'i')),
external_edges=True)

r = state.add_read('A')
n = state.add_nested_sdfg(nsdfg, None, {'a'}, {'b'})
w = state.add_write('A')
state.add_edge(r, None, n, 'a', dace.Memlet.simple('A', '1'))
state.add_edge(n, 'b', w, None, dace.Memlet.simple('A', '0'))

return sdfg


def test_nss():
sdfg = _construct_sdfg()
A = np.random.rand(2)
sdfg(A=A)
assert A[0] == A[1] * 5


if __name__ == '__main__':
test_nss()

0 comments on commit 9d3b855

Please sign in to comment.