Skip to content

Commit

Permalink
Symbol redeclaration fix (#1788)
Browse files Browse the repository at this point in the history
If a symbol is passed as an argument to the SDFG and assigned on an
interstate edge, the framecode declares the scalar symbol again. This PR
fixes that issue.

The case: If there is an interstate assignment on a variable, `N=5`, but
if 'N` is also an input argument to the SDFG, then the generated code
does not compile because `N` is redeclared.
  • Loading branch information
ThrudPrimrose authored Dec 16, 2024
1 parent e82870a commit 4963e6b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 6 deletions.
12 changes: 8 additions & 4 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def dispatcher(self):
def preprocess(self, sdfg: SDFG) -> None:
"""
Called before code generation. Used for making modifications on the SDFG prior to code generation.
:note: Post-conditions assume that the SDFG will NOT be changed after this point.
:param sdfg: The SDFG to modify in-place.
"""
Expand Down Expand Up @@ -900,6 +900,8 @@ def generate_code(self,
# Allocate outer-level transients
self.allocate_arrays_in_scope(sdfg, sdfg, sdfg, global_stream, callsite_stream)

outside_symbols = sdfg.arglist() if is_top_level else set()

# Define constants as top-level-allocated
for cname, (ctype, _) in sdfg.constants_prop.items():
if isinstance(ctype, data.Array):
Expand Down Expand Up @@ -952,10 +954,12 @@ def generate_code(self,
and config.Config.get('compiler', 'fpga', 'vendor').lower() == 'intel_fpga'):
# Emit OpenCL type
callsite_stream.write(f'{isvarType.ocltype} {isvarName};\n', sdfg)
self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype)
else:
callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype)

# If the variable is passed as an input argument to the SDFG, do not need to declare it
if isvarName not in outside_symbols:
callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype)
callsite_stream.write('\n', sdfg)

#######################################################################
Expand Down
5 changes: 3 additions & 2 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def validate_control_flow_region(sdfg: 'SDFG',

def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context: bool):
""" Verifies the correctness of an SDFG by applying multiple tests.
:param sdfg: The SDFG to verify.
:param references: An optional set keeping seen IDs for object
miscopy validation.
Expand Down Expand Up @@ -331,14 +331,15 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
for sym in desc.free_symbols:
symbols[str(sym)] = sym.dtype
validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context)


except InvalidSDFGError as ex:
# If the SDFG is invalid, save it
fpath = os.path.join('_dacegraphs', 'invalid.sdfgz')
sdfg.save(fpath, exception=ex, compress=True)
ex.path = fpath
raise


def _accessible(sdfg: 'dace.sdfg.SDFG', container: str, context: Dict[str, bool]):
"""
Helper function that returns False if a data container cannot be accessed in the current SDFG context.
Expand Down
51 changes: 51 additions & 0 deletions tests/interstate_assignment_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Dict
import dace

N = dace.symbol("N")

def _get_interstate_dependent_sdfg(assignments: Dict, symbols_at_start=False):
sdfg = dace.SDFG("interstate_dependent")
for k in assignments:
sdfg.add_symbol(k, dace.int32)

s1 = sdfg.add_state("s1")
s2 = sdfg.add_state("s2")

if not symbols_at_start:
s0 = sdfg.add_state("s0")
pre_assignments = dict()
for k,v in assignments.items():
pre_assignments[k] = v*2
sdfg.add_edge(s0, s1, dace.InterstateEdge(None, assignments=pre_assignments))

for sid, s in [("1", s1), ("2", s2)]:
sdfg.add_array(f"array{sid}", (N, ) , dace.int32, storage=dace.StorageType.CPU_Heap, transient=True)
an = s.add_access(f"array{sid}")
an2 = s.add_access(f"array{sid}")
t = s.add_tasklet(f"tasklet{sid}", {"_in"}, {"_out"}, "_out = _in * 2")
map_entry, map_exit = s.add_map(f"map{sid}", {"i":dace.subsets.Range([(0,N-1,1)])})
for m in [map_entry, map_exit]:
m.add_in_connector(f"IN_array{sid}")
m.add_out_connector(f"OUT_array{sid}")
s.add_edge(an, None, map_entry, f"IN_array{sid}", dace.memlet.Memlet(f"array{sid}[0:N]"))
s.add_edge(map_entry, f"OUT_array{sid}", t, "_in", dace.memlet.Memlet(f"array{sid}[i]"))
s.add_edge(t, "_out", map_exit, f"IN_array{sid}", dace.memlet.Memlet(f"array{sid}[i]"))
s.add_edge(map_exit, f"OUT_array{sid}", an2, None, dace.memlet.Memlet(f"array{sid}[0:N]"))

sdfg.add_edge(s1, s2, dace.InterstateEdge(None, assignments=assignments))
sdfg.validate()
return sdfg

def test_interstate_assignment():
sdfg = _get_interstate_dependent_sdfg({"N": 5}, False)
sdfg.validate()
sdfg()

def test_interstate_assignment_on_sdfg_input():
sdfg = _get_interstate_dependent_sdfg({"N": 5}, True)
sdfg.validate()
sdfg(N=10)

if __name__ == "__main__":
test_interstate_assignment()
test_interstate_assignment_on_sdfg_input()

0 comments on commit 4963e6b

Please sign in to comment.