Skip to content

Commit

Permalink
Pathed RefineNestedAccess transformation.
Browse files Browse the repository at this point in the history
The main addition is a patch to detect some cases that the RefineNestedAccess Transformation can not handle.
HOWEVER IT IS IMPORTANT TO NOTE, THAT THIS DOES NOT SOLVES ALL OF ITS PROBLEMS, SEE PREVOUS COMMITS.
  • Loading branch information
philip-paul-mueller committed Oct 23, 2024
1 parent 2c7a4d9 commit bf5e4aa
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,11 +929,23 @@ def _candidates(
read_set, write_set = nstate.read_and_write_sets()
for e in nstate.in_edges(dnode):
if e.data.data not in write_set:
# Skip data which is not in the read and write set of the state -> there also won't be a
# connector
# Skip data which is not in the write set of the state -> there also won't be a connector
continue
# If more than one unique element detected, remove from
# candidates

# NOTE: This test is needed to exclude "non canonical Memlets", that can not be handled by
# this transformation. The main issue is that the transformation unconditionally calls
# `.subset` without checking if this set really refers to the correct subset.
# This fix is essentially needed to make
# `tests/transformations/refine_nested_access_test.py::test_rna_read_and_write_sets_doule_use_non_canon_memlet`
# pass. The same scanning is applied to the `in_candidate`. It is however, not enough to make
# `tests/transformations/refine_nested_access_test.py::test_rna_read_and_write_sets_different_storage_non_canon_memlet` pass.
# TODO: Fix this correctly.
if e.data.dst_subset is not e.data.subset:
return ({}, {})
if e.data.data != dnode.data:
return ({}, {})

# If more than one unique element detected, remove from candidates
if e.data.data in out_candidates:
memlet, ns, indices = out_candidates[e.data.data]
# Try to find dimensions in which there is a mismatch
Expand All @@ -944,15 +956,23 @@ def _candidates(
if len(indices) == 0:
ignore.add(e.data.data)
out_candidates[e.data.data] = (memlet, ns, indices)
continue
out_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset))))
else:
out_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset))))

for e in nstate.out_edges(dnode):
if e.data.data not in read_set:
# Skip data which is not in the read and write set of the state -> there also won't be a
# connector
# Skip data which is not in the read set of the state -> there also won't be a connector
# TODO: This is technically not enough, because the read set might be filtered.
# See `SDFGState.read_and_write_sets()` for more.
continue
# If more than one unique element detected, remove from
# candidates

# See note above.
if e.data.src_subset is not e.data.subset:
return ({}, {})
if e.data.data != dnode.data:
return ({}, {})

# If more than one unique element detected, remove from candidates
if e.data.data in in_candidates:
memlet, ns, indices = in_candidates[e.data.data]
# Try to find dimensions in which there is a mismatch
Expand All @@ -963,8 +983,8 @@ def _candidates(
if len(indices) == 0:
ignore.add(e.data.data)
in_candidates[e.data.data] = (memlet, ns, indices)
continue
in_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset))))
else:
in_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset))))

# Check read memlets in interstate edges for candidates
for e in nsdfg.sdfg.edges():
Expand All @@ -979,8 +999,8 @@ def _candidates(
if len(indices) == 0:
ignore.add(m.data)
in_candidates[m.data] = (memlet, ns, indices)
continue
in_candidates[m.data] = (m, None, set(range(len(m.subset))))
else:
in_candidates[m.data] = (m, None, set(range(len(m.subset))))

# Check in/out candidates
for cand in in_candidates.keys() & out_candidates.keys():
Expand Down

0 comments on commit bf5e4aa

Please sign in to comment.