Fix to Read and Write Sets #1678

Merged: 19 commits into `master` from `read-write-sets` on Oct 23, 2024.

Changes from 7 commits (of 19):
145c0ea  Added tests for the `_read_and_write_sets()`. (philip-paul-mueller, Oct 11, 2024)
38e748b  Added the fix from my MapFusion PR. (philip-paul-mueller, Oct 11, 2024)
3748c03  Now made `read_and_write_sets()` fully adhere to their own definition. (philip-paul-mueller, Oct 11, 2024)
3ab4bf3  Updated a test for the `PruneConnectors` transformation. (philip-paul-mueller, Oct 11, 2024)
b4feddf  Added code to `test_more_than_a_map` to ensure that the transformatio… (philip-paul-mueller, Oct 11, 2024)
e1c25b2  Merge remote-tracking branch 'spcl/master' into read-write-sets (philip-paul-mueller, Oct 14, 2024)
70fa3db  Added the new memlet creation syntax. (philip-paul-mueller, Oct 14, 2024)
b187a82  Modified some comments to make them clearer. (philip-paul-mueller, Oct 14, 2024)
9c6cb6c  Modified the `tests/transformations/move_loop_into_map_test.py::test_…` (philip-paul-mueller, Oct 14, 2024)
b5fc16f  Merge branch 'master' into read-write-sets (philip-paul-mueller, Oct 22, 2024)
b7fe242  Added a test to highlights the error. (philip-paul-mueller, Oct 22, 2024)
b546b07  I now removed the filtering inside the read and write set. (philip-paul-mueller, Oct 22, 2024)
ae20590  Fixed `state_test.py::test_read_and_write_set_filter`. (philip-paul-mueller, Oct 23, 2024)
db211fa  Fixed the `state_test.py::test_read_write_set` test. (philip-paul-mueller, Oct 23, 2024)
570437b  Fixed the `state_test.py::test_read_write_set_y_formation` test. (philip-paul-mueller, Oct 23, 2024)
cb80f0b  Fixed `move_loop_into_map_test.py::MoveLoopIntoMapTest::test_more_tha…` (philip-paul-mueller, Oct 23, 2024)
b704a43  Fixed `prune_connectors_test.py::test_read_write_*`. (philip-paul-mueller, Oct 23, 2024)
f74d6e8  General improvements to some tests. (philip-paul-mueller, Oct 23, 2024)
e103924  Updated `refine_nested_access_test.py::test_rna_read_and_write_sets_d…` (philip-paul-mueller, Oct 23, 2024)
dace/sdfg/state.py (107 changes: 73 additions & 34 deletions)

```diff
@@ -745,51 +745,90 @@ def update_if_not_none(dic, update):
 
         return defined_syms
 
 
     def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, List[Subset]]]:
         """
        Determines what data is read and written in this subgraph, returning
        dictionaries from data containers to all subsets that are read/written.
        """
+        from dace.sdfg import utils  # Avoid cyclic import
+
+        # Ensures that the `{src,dst}_subset` are properly set.
+        #  TODO: find where the problems are
```
> **Collaborator**, commenting on the `TODO: find where the problems are` line:
>
> This is a very ugly hack, but I can replicate / encounter the issue too. I am fine with leaving this in, but please make sure there is an issue that keeps track of this TODO somewhere.

```diff
+        for edge in self.edges():
+            edge.data.try_initialize(self.sdfg, self, edge)
+
         read_set = collections.defaultdict(list)
         write_set = collections.defaultdict(list)
-        from dace.sdfg import utils  # Avoid cyclic import
-        subgraphs = utils.concurrent_subgraphs(self)
-        for sg in subgraphs:
-            rs = collections.defaultdict(list)
-            ws = collections.defaultdict(list)
+
+        for subgraph in utils.concurrent_subgraphs(self):
+            subgraph_read_set = collections.defaultdict(list)  # read and write set of this subgraph.
+            subgraph_write_set = collections.defaultdict(list)
             # Traverse in topological order, so data that is written before it
             # is read is not counted in the read set
-            for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()):
-                if isinstance(n, nd.AccessNode):
-                    in_edges = sg.in_edges(n)
-                    out_edges = sg.out_edges(n)
-                    # Filter out memlets which go out but the same data is written to the AccessNode by another memlet
-                    for out_edge in list(out_edges):
-                        for in_edge in list(in_edges):
-                            if (in_edge.data.data == out_edge.data.data
-                                    and in_edge.data.dst_subset.covers(out_edge.data.src_subset)):
-                                out_edges.remove(out_edge)
-                                break
-
-                    for e in in_edges:
-                        # skip empty memlets
-                        if e.data.is_empty():
-                            continue
-                        # Store all subsets that have been written
-                        ws[n.data].append(e.data.subset)
-                    for e in out_edges:
-                        # skip empty memlets
-                        if e.data.is_empty():
-                            continue
-                        rs[n.data].append(e.data.subset)
-            # Union all subgraphs, so an array that was excluded from the read
-            # set because it was written first is still included if it is read
-            # in another subgraph
-            for data, accesses in rs.items():
+            # TODO: This only works if every data descriptor is only once in a path.
```
> **Collaborator**, commenting on the `TODO: This only works if every data descriptor is only once in a path.` line:
>
> What is the failure mode here if it appears multiple times?

> **philip-paul-mueller (Author):**
>
> It took me some time to understand it again: essentially, each access node is processed individually and then everything is combined.
> An example: assume that the first time the data is encountered, the range `0:10` is written and read, so the read set is empty and the write set is `0:10` for that access node.
> The second time, the subset `10:20` is written and `2:6` is read, therefore the read set is `2:6` and the write set is `10:20`.
> The final write set is `{0:10, 10:20}` and the read set is `{2:6}`.
> However, in this case the read set could arguably be determined as `{}`, since the read is covered by the earlier write.
> But the more I think about it, this could actually catch some edge case where that would be incorrect.
> In the end it is an over-approximation.
>
> I turned the TODO into a NOTE and added a better description.
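The scenario in this reply can be written down as a small SDFG. The following is a hypothetical sketch, not code from this PR; the graph, the array names, and the expected sets are assumptions based on the behaviour described above:

```python
import dace
from dace import Memlet

# Hypothetical illustration: the same array 'A' appears at two access nodes
# on one path, and each access node is filtered individually.
sdfg = dace.SDFG('rw_set_overapproximation')
state = sdfg.add_state()
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [10], dace.float64)
sdfg.add_array('C', [4], dace.float64)

a1, a2 = state.add_access('A'), state.add_access('A')
b1, b2 = state.add_access('B'), state.add_access('B')

# First encounter of 'A': A[0:10] is written and A[0:10] is read; the write
# covers the read, so only the write survives at this access node.
state.add_nedge(b1, a1, Memlet('B[0:10] -> [0:10]'))
state.add_nedge(a1, b2, Memlet('A[0:10] -> [0:10]'))

# Second encounter: A[10:20] is written and A[2:6] is read; the write does
# not cover the read, so A[2:6] stays in the read set, even though the
# first access node already wrote that region.
state.add_nedge(b2, a2, Memlet('B[0:10] -> [10:20]'))
state.add_nedge(a2, state.add_access('C'), Memlet('A[2:6] -> [0:4]'))

read_set, write_set = state._read_and_write_sets()
# Expected (assumed): read_set  ~ {'A': [2:6], 'B': [0:10]}
#                     write_set ~ {'A': [0:10, 10:20], 'B': [0:10], 'C': [0:4]}
```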

```diff
+            for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()):
+                if not isinstance(n, nd.AccessNode):
+                    # Read and writes can only be done through access nodes,
+                    #  so ignore every other node.
+                    continue
+
+                # Get a list of all incoming (writes) and outgoing (reads) edges of the
+                #  access node, ignore all empty memlets as they do not carry any data.
+                in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()]
+                out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()]
+
+                # Extract the subsets that describe where we read and write the data
+                #  and store them for the later filtering.
+                # NOTE: In certain cases the corresponding subset might be None, in which case
+                #  we assume that the whole array is written, which is the default behaviour.
+                ac_desc = n.desc(self.sdfg)
+                ac_size = ac_desc.total_size
+                in_subsets = dict()
+                for in_edge in in_edges:
+                    # Ensure that, if the destination subset is not given, our assumption that the
+                    #  whole array is written to is valid, by testing if the memlet transfers the
+                    #  whole array.
+                    assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size)
+                    in_subsets[in_edge] = (
+                        sbs.Range.from_array(ac_desc)
+                        if in_edge.data.dst_subset is None
+                        else in_edge.data.dst_subset
+                    )
+                out_subsets = dict()
+                for out_edge in out_edges:
+                    assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size)
+                    out_subsets[out_edge] = (
+                        sbs.Range.from_array(ac_desc)
+                        if out_edge.data.src_subset is None
+                        else out_edge.data.src_subset
+                    )
+
+                # If a memlet reads a particular region of data from the access node and there
+                #  exists a memlet at the same access node that writes to the same region, then
+                #  this read is ignored and not included in the final read set, but only
+                #  accounted for in the write set. See also the note below.
+                # TODO: Handle the case when multiple disjoint writes are needed to cover the read.
```
> **Collaborator**, commenting on the `TODO: Handle the case when multiple disjoint writes are needed to cover the read.` line:
>
> I assume from the code that the failure mode here is 'conservative', i.e., we overapproximate something as a read if we have disjoint writes that would cover it?

> **philip-paul-mueller (Author):**
>
> Yes.
> To be clear: let's say `e1` writes `0:10` and `e2` writes `10:20`, but at the same node `e3` reads `5:15`; then it would still be considered a read and a write, and not just a write.
> In that sense it is an improvement.
>
> I made the description a bit clearer.
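The disjoint-write case can be sketched the same way. Again a hypothetical illustration under the same assumptions, not code from this PR:

```python
import dace
from dace import Memlet

# Hypothetical illustration: e1 writes A[0:10] and e2 writes A[10:20]; at the
# same access node e3 reads A[5:15]. No single write covers the read, so the
# read is conservatively kept, although the two writes together cover it.
sdfg = dace.SDFG('disjoint_writes_cover_read')
state = sdfg.add_state()
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [10], dace.float64)
sdfg.add_array('C', [10], dace.float64)
sdfg.add_array('D', [10], dace.float64)

a = state.add_access('A')
state.add_nedge(state.add_access('B'), a, Memlet('B[0:10] -> [0:10]'))   # e1
state.add_nedge(state.add_access('C'), a, Memlet('C[0:10] -> [10:20]'))  # e2
state.add_nedge(a, state.add_access('D'), Memlet('A[5:15] -> [0:10]'))   # e3

read_set, write_set = state._read_and_write_sets()
# Expected (assumed): 'A' shows up both as a read (5:15) and as writes
# (0:10 and 10:20), rather than as writes only.
```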

```diff
+                for out_edge in list(out_edges):
+                    for in_edge in in_edges:
+                        if in_subsets[in_edge].covers(out_subsets[out_edge]):
+                            out_edges.remove(out_edge)
+                            break
+
+                # Update the read and write sets of the subgraph.
+                if in_edges:
+                    subgraph_write_set[n.data].extend(in_subsets.values())
+                if out_edges:
+                    subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges)
+
+            # Add the subgraph's read and write set to the final ones.
+            for data, accesses in subgraph_read_set.items():
                 read_set[data] += accesses
-            for data, accesses in ws.items():
+            for data, accesses in subgraph_write_set.items():
                 write_set[data] += accesses
-        return read_set, write_set
+
+        return copy.deepcopy((read_set, write_set))
 
 
     def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
         """
```
tests/sdfg/state_test.py (84 changes: 84 additions & 0 deletions)

```diff
@@ -1,5 +1,6 @@
 # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
 import dace
+from dace import subsets as sbs
 from dace.transformation.helpers import find_sdfg_control_flow
 
 
@@ -43,6 +44,7 @@ def test_read_write_set_y_formation():
 
     assert 'B' not in state.read_and_write_sets()[0]
 
+
 def test_deepcopy_state():
     N = dace.symbol('N')
 
@@ -58,6 +60,86 @@ def double_loop(arr: dace.float32[N]):
     sdfg.validate()
 
 
+def test_read_and_write_set_filter():
+    sdfg = dace.SDFG('graph')
+    state = sdfg.add_state('state')
+    sdfg.add_array('A', [2, 2], dace.float64)
+    sdfg.add_scalar('B', dace.float64)
+    sdfg.add_array('C', [2, 2], dace.float64)
+    A, B, C = (state.add_access(name) for name in ('A', 'B', 'C'))
+
+    state.add_nedge(
+        A,
+        B,
+        dace.Memlet("B[0] -> [0, 0]"),
+    )
+    state.add_nedge(
+        B,
+        C,
+        dace.Memlet("C[1, 1] -> [0]"),
+    )
+    state.add_nedge(
+        B,
+        C,
+        dace.Memlet("B[0] -> [0, 0]"),
+    )
+    sdfg.validate()
+
+    expected_reads = {
+        "A": [sbs.Range.from_string("0, 0")],
+    }
+    expected_writes = {
+        "B": [sbs.Range.from_string("0")],
+        "C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")],
+    }
+    read_set, write_set = state._read_and_write_sets()
+
+    for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]:
+        assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
+        for access_data in expected_sets.keys():
+            for exp in expected_sets[access_data]:
+                found_match = False
+                for res in computed_sets[access_data]:
+                    if res == exp:
+                        found_match = True
+                        break
+                assert found_match, f"Could not find the subset '{exp}', only got '{computed_sets}'"
+
+
+def test_read_and_write_set_selection():
+    sdfg = dace.SDFG('graph')
+    state = sdfg.add_state('state')
+    sdfg.add_array('A', [2, 2], dace.float64)
+    sdfg.add_scalar('B', dace.float64)
+    A, B = (state.add_access(name) for name in ('A', 'B'))
+
+    state.add_nedge(
+        A,
+        B,
+        dace.Memlet("A[0, 0]"),
+    )
+    sdfg.validate()
+
+    expected_reads = {
+        "A": [sbs.Range.from_string("0, 0")],
+    }
+    expected_writes = {
+        "B": [sbs.Range.from_string("0")],
+    }
+    read_set, write_set = state._read_and_write_sets()
+
+    for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]:
+        assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
+        for access_data in expected_sets.keys():
+            for exp in expected_sets[access_data]:
+                found_match = False
+                for res in computed_sets[access_data]:
+                    if res == exp:
+                        found_match = True
+                        break
+                assert found_match, f"Could not find the subset '{exp}', only got '{computed_sets}'"
+
+
 def test_add_mapped_tasklet():
     sdfg = dace.SDFG("test_add_mapped_tasklet")
     state = sdfg.add_state(is_start_block=True)
@@ -82,6 +164,8 @@ def test_add_mapped_tasklet():
 
 
 if __name__ == '__main__':
+    test_read_and_write_set_selection()
+    test_read_and_write_set_filter()
     test_read_write_set()
     test_read_write_set_y_formation()
     test_deepcopy_state()
```
tests/transformations/move_loop_into_map_test.py (22 changes: 21 additions & 1 deletion)

```diff
@@ -2,6 +2,7 @@
 import dace
 from dace.transformation.interstate import MoveLoopIntoMap
 import unittest
+import copy
 import numpy as np
 
 I = dace.symbol("I")
@@ -170,7 +171,26 @@ def test_more_than_a_map(self):
         body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr))
         body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr))
         sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1')
-        count = sdfg.apply_transformations(MoveLoopIntoMap)
+
+        sdfg_args_ref = {
+            "A": np.array(np.random.rand(3, 3), dtype=np.float64),
+            "B": np.array(np.random.rand(3, 3), dtype=np.float64),
+            "out": np.array(np.random.rand(3, 3), dtype=np.float64),
+        }
+        sdfg_args_res = copy.deepcopy(sdfg_args_ref)
+
+        # Perform the reference execution
+        sdfg(**sdfg_args_ref)
+
+        # Apply the transformation and execute the SDFG again.
+        count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True)
+        sdfg(**sdfg_args_res)
+
+        for name in sdfg_args_ref.keys():
+            self.assertTrue(
+                np.allclose(sdfg_args_ref[name], sdfg_args_res[name]),
+                f"Mismatch for {name}",
+            )
         self.assertFalse(count > 0)
 
     def test_more_than_a_map_1(self):
```
tests/transformations/prune_connectors_test.py (9 changes: 5 additions & 4 deletions)

```diff
@@ -153,7 +153,6 @@ def _make_read_write_sdfg(
 
     Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B`
     will either be associated to `inner_A` (`True`) or `inner_B` (`False`).
-    This choice has consequences on if the transformation can apply or not.
 
     Notes:
         This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643),
@@ -421,18 +420,20 @@ def test_prune_connectors_with_dependencies():
 
 
 def test_read_write_1():
-    # Because the memlet is conforming, we can apply the transformation.
     sdfg, nsdfg = _make_read_write_sdfg(True)
 
     assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)
     sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True)
 
 
 def test_read_write_2():
-    # Because the memlet is not conforming, we can not apply the transformation.
+    # In previous versions of DaCe the transformation could not be applied in the
+    #  case of a non-conforming Memlet.
+    #  See [PR#1678](https://github.com/spcl/dace/pull/1678)
     sdfg, nsdfg = _make_read_write_sdfg(False)
 
-    assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)
+    assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)
+    sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True)
 
 
 if __name__ == "__main__":
```