Skip to content

Commit

Permalink
When checking the "array usage" criterion that can prevent
Browse files Browse the repository at this point in the history
map-fusion, check only within the current state. Otherwise, any "use" of
the array _globally_ (i.e., in the entire SDFG) will cancel the fusion.
  • Loading branch information
pratyai committed Oct 17, 2024
1 parent bbe8132 commit c0906da
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False):
intermediate_data.add(dst.data)

# If array is used anywhere else in this state.
num_occurrences = len([n for n in sdfg.data_nodes() if n.data == dst.data])
num_occurrences = len([n for n in graph.data_nodes() if n.data == dst.data])
if num_occurrences > 1:
return False
else:
Expand Down
33 changes: 33 additions & 0 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,38 @@ def test_fusion_simple():
assert diff <= 1e-3


def fusion_twostate():
    """Return an SDFG with two `fusion` start states joined by an edge.

    Intended to have the same effect as applying `fusion()` twice.
    """
    state_fusion = dace.transformation.interstate.StateFusionExtended

    def _build(label):
        # Convert `fusion` to a fresh SDFG; the start block is relabelled
        # to avoid a name conflict between the two copies.
        built = fusion.to_sdfg(simplify=True, validate=True)
        built.start_block.label = label
        built.apply_transformations(state_fusion, validate_all=True)
        return built

    first = _build('st0')
    # Construct the second block
    second = _build('st1')

    # Connect the first SDFG's start state to the second SDFG's start state.
    first.add_edge(first.start_state, second.start_state, dace.InterstateEdge())
    return first


def test_fusion_twostate():
    """Apply MapFusion to the two-state SDFG and check its numerical result."""
    dump_dir = '_dacegraphs'
    sdfg = fusion_twostate()

    # Save the graph before and after the transformation for inspection.
    sdfg.save(os.path.join(dump_dir, 'before-twostate.sdfg'))
    sdfg.apply_transformations_repeated(MapFusion)
    sdfg.save(os.path.join(dump_dir, 'after-twostate.sdfg'))

    shape = (10, 20)
    A = np.random.rand(*shape).astype(np.float32)
    B = np.random.rand(*shape).astype(np.float32)
    out = np.zeros(shape=1, dtype=np.float32)
    sdfg(A=A, B=B, out=out)

    # NOTE: `2 * np.sum(A * A + B)` is the expected result of applying `fusion()` twice.
    expected = 2 * np.sum(A * A + B)
    diff = abs(expected - out)
    print('Difference:', diff)
    assert diff <= 1e-3


def test_multiple_fusions():
sdfg = multiple_fusions.to_sdfg()
num_nodes_before = len([node for state in sdfg.nodes() for node in state.nodes()])
Expand Down Expand Up @@ -312,6 +344,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3

if __name__ == '__main__':
test_fusion_simple()
test_fusion_twostate()
test_multiple_fusions()
test_fusion_chain()
test_fusion_with_transient()
Expand Down

0 comments on commit c0906da

Please sign in to comment.