From 2c7a4d9455bd86ceb3f4e33509ec8383dda1a2b3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 23 Oct 2024 14:19:45 +0200 Subject: [PATCH] Added some more tests to the refined access transformation. It is important that one test has to be disabled. This is because of a fundamental flaw in how refine nested access works. --- .../refine_nested_access_test.py | 112 +++++++++++++++--- 1 file changed, 95 insertions(+), 17 deletions(-) diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index 3e5b97e9a0..5e86566143 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -156,7 +156,10 @@ def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int): assert np.allclose(ref, val) -def _make_rna_read_and_write_set_sdfg(diff_in_out: bool) -> dace.SDFG: +def _make_rna_read_and_write_set_sdfg( + diff_in_out: bool, + canonical_memlet_direction: bool, +) -> dace.SDFG: """Generates the SDFG for the `test_rna_read_and_write_sets_*()` tests. If `diff_in_out` is `False` then the output is also used as temporary storage @@ -167,22 +170,30 @@ def _make_rna_read_and_write_set_sdfg(diff_in_out: bool) -> dace.SDFG: If `diff_in_out` is true, then a different storage container, which is classified as output, is used as temporary storage. + By setting `canonical_memlet_direction` to `False` the function will generate + a Memlet with a non canonical direction. This affects the Memlet that copies + data from `A` into the intermediate, which then might be filtered out, but only + in cases `diff_in_out` is `True`. + This test was added during [PR#1678](https://github.com/spcl/dace/pull/1678). 
""" - def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: + def _make_nested_sdfg( + diff_in_out: bool, + canonical_memlet_direction: bool, + ) -> dace.SDFG: sdfg = dace.SDFG("inner_sdfg") state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", dtype=dace.float64, shape=(2,), transient=False) sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False) A = state.add_access("A") - T1_input = state.add_access("T1") + T1_output = state.add_access("T1") if diff_in_out: sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False) - T1_output = state.add_access("T2") + T1_input = state.add_access("T2") else: - T1_output = state.add_access("T1") + T1_input = state.add_access("T1") tsklt = state.add_tasklet( "comp", @@ -192,10 +203,12 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: ) state.add_edge(A, None, tsklt, "__in1", dace.Memlet("A[1]")) - # An alternative would be to write to a different location here. - # Then, the data would be added to the access node. 
- state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]")) - state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet("T1[0]")) + if canonical_memlet_direction: + state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]")) + else: + state.add_edge(A, None, T1_input, None, dace.Memlet(data=T1_input.data, subset="0", other_subset="0")) + + state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet(T1_input.data + "[0]")) state.add_edge(tsklt, "__out", T1_output, None, dace.Memlet(T1_output.data + "[1]")) return sdfg @@ -208,7 +221,7 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: A = state.add_access("A") T1 = state.add_access("T1") - nested_sdfg = _make_nested_sdfg(diff_in_out) + nested_sdfg = _make_nested_sdfg(diff_in_out, canonical_memlet_direction) nsdfg = state.add_nested_sdfg( nested_sdfg, @@ -227,12 +240,11 @@ def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: return sdfg -def test_rna_read_and_write_sets_doule_use(): - +def test_rna_read_and_write_sets_doule_use_canon_memlet(): # Because the same storage is used in multiple places (in the same data flow graph) # it will not be picked up by the transformation. Because of its dependency # on `SDFGState.read_and_write_sets()`. - sdfg = _make_rna_read_and_write_set_sdfg(False) + sdfg = _make_rna_read_and_write_set_sdfg(False, True) nb_applied = sdfg.apply_transformations_repeated( [RefineNestedAccess], validate=True, @@ -240,11 +252,63 @@ def test_rna_read_and_write_sets_doule_use(): ) assert nb_applied == 0 + # Test if the SDFG is not changed. + args = { + "A": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T2": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T1": np.zeros(2, dtype=np.float64), + } + ref = args["A"][0] + args["A"][1] + sdfg(**args) + res = args["T1"][1] -def test_rna_read_and_write_sets_different_storage(): + assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'." 
+ + +def test_rna_read_and_write_sets_doule_use_non_canon_memlet(): + # Because the same storage is used in multiple places (in the same data flow graph) + # it will not be picked up by the transformation, because of its dependency + # on `SDFGState.read_and_write_sets()`. + sdfg = _make_rna_read_and_write_set_sdfg(False, False) + nb_applied = sdfg.apply_transformations( + [RefineNestedAccess], + validate=True, + validate_all=True, + ) + assert nb_applied == 0 + +def test_rna_read_and_write_sets_different_storage_canon_memlet(): # Because different storage is used the transformation will apply here. - sdfg = _make_rna_read_and_write_set_sdfg(True) + sdfg = _make_rna_read_and_write_set_sdfg(True, True) + nb_applied = sdfg.apply_transformations_repeated( + [RefineNestedAccess], + validate=True, + validate_all=True, + ) + assert nb_applied > 0 + + # Test that the transformed SDFG still computes the correct result. + args = { + "A": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T2": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T1": np.zeros(2, dtype=np.float64), + } + ref = args["A"][0] + args["A"][1] + sdfg(**args) + res = args["T1"][1] + + assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'." + + +def test_rna_read_and_write_sets_different_storage_non_canon_memlet(): + # NOTE: This transformation should actually apply and work. However, for some reason the `RefineNestedAccess` + # transformation is modifying the access to `A` in this mode (if the Memlet is canonical it works). + # The cause is a fundamental flaw in the transformation. For that reason this test is disabled. + return + + sdfg = _make_rna_read_and_write_set_sdfg(True, False) nb_applied = sdfg.apply_transformations_repeated( [RefineNestedAccess], validate=True, @@ -252,10 +316,24 @@ def test_rna_read_and_write_sets_different_storage(): ) assert nb_applied > 0 + # Test that the transformed SDFG still computes the correct result. 
+ args = { + "A": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T2": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T1": np.zeros(2, dtype=np.float64), + } + ref = args["A"][0] + args["A"][1] + sdfg(**args) + res = args["T1"][1] + + assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'." + if __name__ == '__main__': test_refine_dataflow() test_refine_interstate() test_free_symbols_only_by_indices() - test_rna_read_and_write_sets_different_storage() - test_rna_read_and_write_sets_doule_use() + test_rna_read_and_write_sets_different_storage_canon_memlet() + test_rna_read_and_write_sets_different_storage_non_canon_memlet() + test_rna_read_and_write_sets_doule_use_canon_memlet() + test_rna_read_and_write_sets_doule_use_non_canon_memlet()