diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py
index 844651b071..34bba101fb 100644
--- a/dace/transformation/interstate/gpu_transform_sdfg.py
+++ b/dace/transformation/interstate/gpu_transform_sdfg.py
@@ -438,6 +438,25 @@ def apply(self, _, sdfg: sd.SDFG):
         #######################################################
         # Step 7: Wrap free tasklets and nested SDFGs with a GPU map
 
+        # Extend global_code_nodes with free tasklets that read from or write to GPU arrays.
+        # The previous steps map all arrays to GPU storage, but they only check tasklets that
+        # read from or write to Scalars, so tasklets that access arrays directly still have to
+        # be collected here in order to be wrapped in a GPU map.
+        for state in sdfg.states():
+            for node in state.nodes():
+                if not isinstance(node, nodes.Tasklet) or node in global_code_nodes[state]:
+                    continue
+                if state.entry_node(node) is not None or scope.is_devicelevel_gpu_kernel(
+                        state.parent, state, node):
+                    continue
+                # Endpoints of the root edges of the memlet trees attached to the tasklet
+                memlet_path_roots = {state.memlet_tree(e).root().edge.src for e in state.in_edges(node)}
+                memlet_path_roots |= {state.memlet_tree(e).root().edge.dst for e in state.out_edges(node)}
+                if any(
+                        isinstance(n, nodes.AccessNode) and sdfg.arrays[n.data].storage in gpu_storage
+                        for n in memlet_path_roots):
+                    global_code_nodes[state].append(node)
+
         for state, gcodes in global_code_nodes.items():
             for gcode in gcodes:
                 if gcode.label in self.exclude_tasklets.split(','):
diff --git a/tests/transformations/gpu_transform_test.py b/tests/transformations/gpu_transform_test.py
index f6d299e630..62f95dcdb9 100644
--- a/tests/transformations/gpu_transform_test.py
+++ b/tests/transformations/gpu_transform_test.py
@@ -118,6 +118,36 @@ def write_subset_dynamic(A: dace.int32[20, 20], x: dace.int32[20], y: dace.int32
 
     assert np.array_equal(ref, val)
 
+@pytest.mark.parametrize(["transient", "scalar"],
+                         [[False, False], [False, True],
+                          [True, False], [True, True]])
+def test_free_tasklet(transient, scalar):
+    """Check that an SDFG with a free tasklet writing to a scalar or array validates after GPU transformation."""
+    sdfg = dace.SDFG("assign")
+
+    state = sdfg.add_state("main")
+    if scalar:
+        arr_name, _ = sdfg.add_scalar("A", dace.float32, transient=transient)
+    else:
+        arr_name, _ = sdfg.add_array("A", (4,), dace.float32, transient=transient)
+
+    an = state.add_access(arr_name)
+
+    t = state.add_tasklet("assign", {}, {"_out"}, "_out = 2.0")
+    state.add_edge(t, "_out", an, None, dace.memlet.Memlet("A" if scalar else "A[0]"))
+
+    sdfg.validate()
+
+    sdfg.apply_gpu_transformations(
+        validate=True,
+        validate_all=True,
+        permissive=True,
+        sequential_innermaps=True,
+        register_transients=False,
+        simplify=False,
+    )
+
+    sdfg.validate()
 
 if __name__ == '__main__':
     test_toplevel_transient_lifetime()
@@ -125,3 +155,6 @@ def write_subset_dynamic(A: dace.int32[20, 20], x: dace.int32[20], y: dace.int32
     test_write_subset()
     test_write_full()
     test_write_subset_dynamic()
+    for scalar in [False, True]:
+        for transient in [False, True]:
+            test_free_tasklet(transient, scalar)