Skip to content

Commit

Permalink
Added some more tests to the refined access transformation.
Browse files Browse the repository at this point in the history
It is important that one test has to be disabled.
This is because of a fundamental flaw in how refine nested access works.
  • Loading branch information
philip-paul-mueller committed Oct 23, 2024
1 parent bd79fc6 commit 2c7a4d9
Showing 1 changed file with 95 additions and 17 deletions.
112 changes: 95 additions & 17 deletions tests/transformations/refine_nested_access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int):
assert np.allclose(ref, val)


def _make_rna_read_and_write_set_sdfg(diff_in_out: bool) -> dace.SDFG:
def _make_rna_read_and_write_set_sdfg(
diff_in_out: bool,
canonical_memlet_direction: bool,
) -> dace.SDFG:
"""Generates the SDFG for the `test_rna_read_and_write_sets_*()` tests.
If `diff_in_out` is `False` then the output is also used as temporary storage
Expand All @@ -167,22 +170,30 @@ def _make_rna_read_and_write_set_sdfg(diff_in_out: bool) -> dace.SDFG:
If `diff_in_out` is true, then a different storage container, which is classified
as output, is used as temporary storage.
By setting `canonical_memlet_direction` to `False` the function will generate
a Memlet with a non canonical direction. This affects the Memlet that copies
data from `A` into the intermediate, which then might be filtered out, but only
in cases `diff_in_out` is `True`.
This test was added during [PR#1678](https://github.com/spcl/dace/pull/1678).
"""

def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG:
def _make_nested_sdfg(
diff_in_out: bool,
canonical_memlet_direction: bool,
) -> dace.SDFG:
sdfg = dace.SDFG("inner_sdfg")
state = sdfg.add_state(is_start_block=True)
sdfg.add_array("A", dtype=dace.float64, shape=(2,), transient=False)
sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False)

A = state.add_access("A")
T1_input = state.add_access("T1")
T1_output = state.add_access("T1")
if diff_in_out:
sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False)
T1_output = state.add_access("T2")
T1_input = state.add_access("T2")
else:
T1_output = state.add_access("T1")
T1_input = state.add_access("T1")

tsklt = state.add_tasklet(
"comp",
Expand All @@ -192,10 +203,12 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG:
)

state.add_edge(A, None, tsklt, "__in1", dace.Memlet("A[1]"))
# An alternative would be to write to a different location here.
# Then, the data would be added to the access node.
state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]"))
state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet("T1[0]"))
if canonical_memlet_direction:
state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]"))
else:
state.add_edge(A, None, T1_input, None, dace.Memlet(data=T1_input.data, subset="0", other_subset="0"))

state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet(T1_input.data + "[0]"))
state.add_edge(tsklt, "__out", T1_output, None, dace.Memlet(T1_output.data + "[1]"))
return sdfg

Expand All @@ -208,7 +221,7 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG:
A = state.add_access("A")
T1 = state.add_access("T1")

nested_sdfg = _make_nested_sdfg(diff_in_out)
nested_sdfg = _make_nested_sdfg(diff_in_out, canonical_memlet_direction)

nsdfg = state.add_nested_sdfg(
nested_sdfg,
Expand All @@ -227,35 +240,100 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG:
return sdfg


def test_rna_read_and_write_sets_doule_use():

def test_rna_read_and_write_sets_doule_use_canon_memlet():
# Because the same storage is used in multiple places (in the same data flow graph)
# it will not be picked up by the transformation. Because of its dependency
# on `SDFGState.read_and_write_sets()`.
sdfg = _make_rna_read_and_write_set_sdfg(False)
sdfg = _make_rna_read_and_write_set_sdfg(False, True)
nb_applied = sdfg.apply_transformations_repeated(
[RefineNestedAccess],
validate=True,
validate_all=True,
)
assert nb_applied == 0

# Test if the SDFG is not changed.
args = {
"A": np.array(np.random.rand(2), dtype=np.float64, copy=True),
"T2": np.array(np.random.rand(2), dtype=np.float64, copy=True),
"T1": np.zeros(2, dtype=np.float64),
}
ref = args["A"][0] + args["A"][1]
sdfg(**args)
res = args["T1"][1]

def test_rna_read_and_write_sets_different_storage():
assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'."



def test_rna_read_and_write_sets_doule_use_non_canon_memlet():
# Because the same storage is used in multiple places (in the same data flow graph)
# it will not be picked up by the transformation. Because of its dependency
# on `SDFGState.read_and_write_sets()`.
sdfg = _make_rna_read_and_write_set_sdfg(False, False)
nb_applied = sdfg.apply_transformations(
[RefineNestedAccess],
validate=True,
validate_all=True,
)
assert nb_applied == 0


def test_rna_read_and_write_sets_different_storage_canon_memlet():
# Because different storage is used the transformation will apply here.
sdfg = _make_rna_read_and_write_set_sdfg(True)
sdfg = _make_rna_read_and_write_set_sdfg(True, True)
nb_applied = sdfg.apply_transformations_repeated(
[RefineNestedAccess],
validate=True,
validate_all=True,
)
assert nb_applied > 0

# Test if the SDFG is not changed.
args = {
"A": np.array(np.random.rand(2), dtype=np.float64, copy=True),
"T2": np.array(np.random.rand(2), dtype=np.float64, copy=True),
"T1": np.zeros(2, dtype=np.float64),
}
ref = args["A"][0] + args["A"][1]
sdfg(**args)
res = args["T1"][1]

assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'."


def test_rna_read_and_write_sets_different_storage_non_canon_memlet():
# NOTE: This transformation should actually apply and work. However for some reasons the `RefineNestedAccess`
# transformation is modifying the access to `A` in this mode (if the Memlet is canonical it works).
# The reason is a fundamental flaw of the transformation. For that reason this test is disabled.
return

sdfg = _make_rna_read_and_write_set_sdfg(True, False)
nb_applied = sdfg.apply_transformations_repeated(
[RefineNestedAccess],
validate=True,
validate_all=True,
)
assert nb_applied > 0

# Test if the SDFG is not changed.
args = {
"A": np.array(np.random.rand(2), dtype=np.float64, copy=True),
"T2": np.array(np.random.rand(2), dtype=np.float64, copy=True),
"T1": np.zeros(2, dtype=np.float64),
}
ref = args["A"][0] + args["A"][1]
sdfg(**args)
res = args["T1"][1]

assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'."


if __name__ == '__main__':
test_refine_dataflow()
test_refine_interstate()
test_free_symbols_only_by_indices()
test_rna_read_and_write_sets_different_storage()
test_rna_read_and_write_sets_doule_use()
test_rna_read_and_write_sets_different_storage_canon_memlet()
test_rna_read_and_write_sets_different_storage_non_canon_memlet()
test_rna_read_and_write_sets_doule_use_canon_memlet()
test_rna_read_and_write_sets_doule_use_non_canon_memlet()

0 comments on commit 2c7a4d9

Please sign in to comment.