From b92c92d6d54642b6053de40c6e4fefed86e1b416 Mon Sep 17 00:00:00 2001 From: Wayne Franz Date: Thu, 12 Dec 2024 16:42:35 -0500 Subject: [PATCH] Fix hipGraph memory leaks (#445) Currently, the hipGraph tests overwrite the graph instance created by hipStreamStartCapture. This change moves the hipGraph helper functions into a class that encapsulates the graph instance, ensuring it doesn't get overwritten. --- ...test_hipcub_device_adjacent_difference.cpp | 23 +-- test/hipcub/test_hipcub_device_for.cpp | 42 ++--- test/hipcub/test_hipcub_device_histogram.cpp | 84 +++------- test/hipcub/test_hipcub_device_merge_sort.cpp | 148 +++++------------- test/hipcub/test_hipcub_device_partition.cpp | 63 +++----- test/hipcub/test_hipcub_device_radix_sort.hpp | 84 +++------- test/hipcub/test_hipcub_device_reduce.cpp | 84 +++------- .../test_hipcub_device_reduce_by_key.cpp | 21 +-- test/hipcub/test_hipcub_device_scan.cpp | 105 ++++--------- .../test_hipcub_device_segmented_reduce.cpp | 84 +++------- test/hipcub/test_hipcub_device_select.cpp | 105 ++++--------- test/hipcub/test_hipcub_device_spmv.cpp | 17 +- test/hipcub/test_utils_hipgraphs.hpp | 105 +++++++------ 13 files changed, 306 insertions(+), 659 deletions(-) diff --git a/test/hipcub/test_hipcub_device_adjacent_difference.cpp b/test/hipcub/test_hipcub_device_adjacent_difference.cpp index 865bdba5..cfb545ac 100644 --- a/test/hipcub/test_hipcub_device_adjacent_difference.cpp +++ b/test/hipcub/test_hipcub_device_adjacent_difference.cpp @@ -215,11 +215,9 @@ TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(dispatch_adjacent_difference(left_constant, copy_constant, @@ -231,11 +229,8 @@ TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy) op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); std::vector output(size); HIP_CHECK(hipMemcpy(output.data(), @@ -253,16 +248,12 @@ TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy) ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(output, expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } // Params for tests @@ -461,4 +452,4 @@ TYPED_TEST(HipcubDeviceAdjacentDifferenceLargeTests, LargeIndicesAndOpOnce) HIP_CHECK(hipFree(d_flags)); } } -} \ No newline at end of file +} diff --git a/test/hipcub/test_hipcub_device_for.cpp b/test/hipcub/test_hipcub_device_for.cpp index 5cd00e81..afee9abf 100644 --- a/test/hipcub/test_hipcub_device_for.cpp +++ b/test/hipcub/test_hipcub_device_for.cpp @@ -115,20 +115,15 @@ TYPED_TEST(HipcubDeviceForTests, ForEach) std::vector expected(input); std::for_each(expected.begin(), expected.end(), plus()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::ForEach(d_input, d_input + size, plus(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -144,18 +139,14 @@ TYPED_TEST(HipcubDeviceForTests, ForEach) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); } } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template @@ -283,20 +274,15 @@ TYPED_TEST(HipcubDeviceForTests, ForEachN) std::vector expected(input); std::for_each(expected.begin(), expected.begin() + n, plus()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::ForEachN(d_input, n, plus(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -312,18 +298,14 @@ TYPED_TEST(HipcubDeviceForTests, ForEachN) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); } } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template diff --git a/test/hipcub/test_hipcub_device_histogram.cpp b/test/hipcub/test_hipcub_device_histogram.cpp index 3754214a..29c9f9a2 100644 --- a/test/hipcub/test_hipcub_device_histogram.cpp +++ b/test/hipcub/test_hipcub_device_histogram.cpp @@ -273,11 +273,9 @@ TYPED_TEST(HipcubDeviceHistogramEven, Even) void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); if(rows == 1) { @@ -306,11 +304,8 @@ TYPED_TEST(HipcubDeviceHistogramEven, Even) stream)); } - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); std::vector histogram(bins); HIP_CHECK( @@ -331,16 +326,12 @@ TYPED_TEST(HipcubDeviceHistogramEven, Even) } if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } // Test HistogramEven overflow @@ -625,11 +616,9 @@ TYPED_TEST(HipcubDeviceHistogramRange, Range) void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); if(rows == 1) { @@ -656,11 +645,8 @@ TYPED_TEST(HipcubDeviceHistogramRange, Range) stream)); } - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); std::vector histogram(bins); HIP_CHECK( @@ -682,15 +668,11 @@ TYPED_TEST(HipcubDeviceHistogramRange, Range) } if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template histogram[active_channels]; for(unsigned int channel = 0; channel < active_channels; channel++) @@ -1005,16 +982,12 @@ TYPED_TEST(HipcubDeviceHistogramMultiEven, MultiEven) } if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template histogram[active_channels]; for(unsigned int channel = 0; channel < active_channels; channel++) @@ -1346,14 +1314,10 @@ TYPED_TEST(HipcubDeviceHistogramMultiRange, MultiRange) } if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } diff --git a/test/hipcub/test_hipcub_device_merge_sort.cpp b/test/hipcub/test_hipcub_device_merge_sort.cpp index fc573108..881fd48d 100644 --- a/test/hipcub/test_hipcub_device_merge_sort.cpp +++ b/test/hipcub/test_hipcub_device_merge_sort.cpp @@ -129,11 +129,9 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeys) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::SortKeys(d_temporary_storage, temporary_storage_bytes, @@ -142,11 +140,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeys) compare_function(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -164,16 +159,12 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeys) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(is_sorted_result, true)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceMergeSort, SortKeysCopy) @@ -237,11 +228,9 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeysCopy) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::SortKeysCopy(d_temporary_storage, temporary_storage_bytes, @@ -250,11 +239,9 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeysCopy) size, compare_function(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -273,16 +260,12 @@ TYPED_TEST(HipcubDeviceMergeSort, SortKeysCopy) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(is_sorted_result, true)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceMergeSort, StableSortKeys) @@ -342,11 +325,9 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeys) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::SortKeys(d_temporary_storage, temporary_storage_bytes, @@ -355,11 +336,8 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeys) compare_function(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -378,16 +356,12 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeys) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceMergeSort, StableSortKeysCopy) @@ -449,11 +423,9 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeysCopy) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::StableSortKeysCopy(d_temporary_storage, temporary_storage_bytes, @@ -463,11 +435,8 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeysCopy) compare_function(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -487,16 +456,12 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortKeysCopy) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceMergeSort, SortPairs) @@ -587,11 +552,9 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairs) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::SortPairs(d_temporary_storage, temporary_storage_bytes, @@ -601,11 +564,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairs) compare_op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -636,16 +596,12 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairs) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceMergeSort, SortPairsCopy) @@ -752,11 +708,9 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairsCopy) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::SortPairsCopy(d_temporary_storage, temporary_storage_bytes, @@ -768,11 +722,8 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairsCopy) compare_op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -797,16 +748,12 @@ TYPED_TEST(HipcubDeviceMergeSort, SortPairsCopy) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceMergeSort, StableSortPairs) @@ -896,11 +843,9 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortPairs) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceMergeSort::StableSortPairs(d_temporary_storage, temporary_storage_bytes, @@ -910,11 +855,8 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortPairs) compare_op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -945,14 +887,10 @@ TYPED_TEST(HipcubDeviceMergeSort, StableSortPairs) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } diff --git a/test/hipcub/test_hipcub_device_partition.cpp b/test/hipcub/test_hipcub_device_partition.cpp index 53764a70..4a66d58d 100644 --- a/test/hipcub/test_hipcub_device_partition.cpp +++ b/test/hipcub/test_hipcub_device_partition.cpp @@ -153,11 +153,9 @@ TYPED_TEST(HipcubDevicePartitionTests, Flagged) HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DevicePartition::Flagged( @@ -170,11 +168,8 @@ TYPED_TEST(HipcubDevicePartitionTests, Flagged) input.size(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipDeviceSynchronize()); @@ -208,9 +203,7 @@ TYPED_TEST(HipcubDevicePartitionTests, Flagged) expected_rejected.size())); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_flags)); @@ -221,9 +214,7 @@ TYPED_TEST(HipcubDevicePartitionTests, Flagged) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } // NOTE: The following lambdas cannot be inside the test because of nvcc @@ -330,11 +321,9 @@ TYPED_TEST(HipcubDevicePartitionTests, If) HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DevicePartition::If( @@ -347,11 +336,8 @@ TYPED_TEST(HipcubDevicePartitionTests, If) select_op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipDeviceSynchronize()); @@ -385,9 +371,7 @@ TYPED_TEST(HipcubDevicePartitionTests, If) expected_rejected.size())); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -397,9 +381,7 @@ TYPED_TEST(HipcubDevicePartitionTests, If) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } namespace @@ -513,11 +495,9 @@ TYPED_TEST(HipcubDevicePartitionTests, IfThreeWay) void* d_temp_storage = nullptr; HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DevicePartition::If( @@ -533,11 +513,8 @@ TYPED_TEST(HipcubDevicePartitionTests, IfThreeWay) second_op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipDeviceSynchronize()); @@ -571,9 +548,7 @@ TYPED_TEST(HipcubDevicePartitionTests, IfThreeWay) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_first_output)); @@ -585,7 +560,5 @@ TYPED_TEST(HipcubDevicePartitionTests, IfThreeWay) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } diff --git a/test/hipcub/test_hipcub_device_radix_sort.hpp b/test/hipcub/test_hipcub_device_radix_sort.hpp index 8c4eaee3..68c8c3f4 100644 --- a/test/hipcub/test_hipcub_device_radix_sort.hpp +++ b/test/hipcub/test_hipcub_device_radix_sort.hpp @@ -288,11 +288,9 @@ void sort_keys() HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(invoke_sort_keys(d_temporary_storage, temporary_storage_bytes, @@ -303,11 +301,8 @@ void sort_keys() end_bit, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -323,16 +318,12 @@ void sort_keys() ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template @@ -560,11 +551,9 @@ void sort_pairs() HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(invoke_sort_pairs(d_temporary_storage, temporary_storage_bytes, @@ -577,11 +566,8 @@ void sort_pairs() end_bit, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -614,16 +600,12 @@ void sort_pairs() ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template @@ -798,11 +780,9 @@ void sort_keys_double_buffer() HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(invoke_sort_keys(d_temporary_storage, temporary_storage_bytes, @@ -812,11 +792,8 @@ void sort_keys_double_buffer() end_bit, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -832,16 +809,12 @@ void sort_keys_double_buffer() ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template @@ -1050,11 +1023,9 @@ void sort_pairs_double_buffer() HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(invoke_sort_pairs(d_temporary_storage, temporary_storage_bytes, @@ -1065,11 +1036,8 @@ void sort_pairs_double_buffer() end_bit, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -1102,16 +1070,12 @@ void sort_pairs_double_buffer() ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } inline void sort_keys_over_4g() diff --git a/test/hipcub/test_hipcub_device_reduce.cpp b/test/hipcub/test_hipcub_device_reduce.cpp index 1ed8dbca..f01cf839 100644 --- a/test/hipcub/test_hipcub_device_reduce.cpp +++ b/test/hipcub/test_hipcub_device_reduce.cpp @@ -156,11 +156,9 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceSum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run reduce_selector.reduce_sum(d_temp_storage, @@ -170,11 +168,8 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceSum) input.size(), stream); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -193,9 +188,7 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceSum) test_utils::precision::value * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -204,9 +197,7 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceSum) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceReduceTests, ReduceMinimum) @@ -277,11 +268,9 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceMinimum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DeviceReduce::Min(d_temp_storage, @@ -291,11 +280,8 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceMinimum) input.size(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -316,9 +302,7 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceMinimum) : std::max(test_utils::precision::value, test_utils::precision::value))); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -327,9 +311,7 @@ TYPED_TEST(HipcubDeviceReduceTests, ReduceMinimum) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } struct ArgMinDispatch @@ -445,11 +427,9 @@ void test_argminmax(typename TestFixture::input_type empty_value) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(function(d_temp_storage, @@ -459,11 +439,8 @@ void test_argminmax(typename TestFixture::input_type empty_value) input.size(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -484,16 +461,12 @@ void test_argminmax(typename TestFixture::input_type empty_value) test_utils::precision::value * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceReduceTests, ReduceArgMinimum) @@ -696,11 +669,9 @@ TYPED_TEST(HipcubDeviceReduceTests, TransformReduce) // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DeviceReduce::TransformReduce(d_temp_storage, @@ -713,11 +684,8 @@ TYPED_TEST(HipcubDeviceReduceTests, TransformReduce) init, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -735,9 +703,7 @@ TYPED_TEST(HipcubDeviceReduceTests, TransformReduce) test_utils::precision::value * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -746,9 +712,7 @@ TYPED_TEST(HipcubDeviceReduceTests, TransformReduce) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } // --------------------------------------------------------- diff --git a/test/hipcub/test_hipcub_device_reduce_by_key.cpp b/test/hipcub/test_hipcub_device_reduce_by_key.cpp index 78e8e7e4..67a60849 100644 --- a/test/hipcub/test_hipcub_device_reduce_by_key.cpp +++ b/test/hipcub/test_hipcub_device_reduce_by_key.cpp @@ -209,11 +209,9 @@ TYPED_TEST(HipcubDeviceReduceByKey, ReduceByKey) void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceReduce::ReduceByKey(d_temporary_storage, temporary_storage_bytes, @@ -226,11 +224,8 @@ TYPED_TEST(HipcubDeviceReduceByKey, ReduceByKey) size, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -276,14 +271,10 @@ TYPED_TEST(HipcubDeviceReduceByKey, ReduceByKey) * TestFixture::params::max_segment_length)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } diff --git a/test/hipcub/test_hipcub_device_scan.cpp b/test/hipcub/test_hipcub_device_scan.cpp index 53187f01..cc1bdd8c 100644 --- a/test/hipcub/test_hipcub_device_scan.cpp +++ b/test/hipcub/test_hipcub_device_scan.cpp @@ -258,20 +258,15 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScan) // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run call(d_temp_storage, temp_storage_size_bytes); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -298,9 +293,7 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScan) test_utils::assert_near(output, expected, single_op_precision * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); if(!inplace) @@ -312,9 +305,7 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScan) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceScanTests, InclusiveScanByKey) @@ -444,11 +435,9 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScanByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run if(std::is_same::value) @@ -475,11 +464,8 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScanByKey) stream)); } - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -496,9 +482,7 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScanByKey) test_utils::assert_near(output, expected, single_op_precision * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_keys)); HIP_CHECK(hipFree(d_input)); @@ -508,9 +492,7 @@ TYPED_TEST(HipcubDeviceScanTests, InclusiveScanByKey) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceScanTests, ExclusiveScan) @@ -668,20 +650,15 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScan) // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run call(d_temp_storage, temp_storage_size_bytes); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -708,9 +685,7 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScan) test_utils::assert_near(output, expected, single_op_precision * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); if(!inplace) @@ -722,9 +697,7 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScan) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanByKey) @@ -867,11 +840,9 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run if(std::is_same::value) @@ -899,11 +870,8 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanByKey) stream)); } - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -920,9 +888,7 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanByKey) test_utils::assert_near(output, expected, single_op_precision * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_keys)); HIP_CHECK(hipFree(d_input)); @@ -932,9 +898,7 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanByKey) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } // CUB does not support large indices in inclusive and exclusive scans @@ -1221,11 +1185,9 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanFuture) fill_initial_value<<<1, 1, 0, stream>>>(d_initial_value, initial_value); HIP_CHECK(hipGetLastError()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DeviceScan::ExclusiveScan(d_temp_storage, @@ -1237,11 +1199,8 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanFuture) input.size(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -1258,9 +1217,7 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanFuture) test_utils::assert_near(output, expected, single_op_precision * size)); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -1270,7 +1227,5 @@ TYPED_TEST(HipcubDeviceScanTests, ExclusiveScanFuture) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } diff --git a/test/hipcub/test_hipcub_device_segmented_reduce.cpp b/test/hipcub/test_hipcub_device_segmented_reduce.cpp index e85d1636..037816a3 100644 --- a/test/hipcub/test_hipcub_device_segmented_reduce.cpp +++ b/test/hipcub/test_hipcub_device_segmented_reduce.cpp @@ -193,11 +193,9 @@ TYPED_TEST(HipcubDeviceSegmentedReduceOp, Reduce) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - hipGraph_t graph; - if(TestFixture::params::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::params::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceSegmentedReduce::Reduce(d_temporary_storage, temporary_storage_bytes, @@ -210,11 +208,8 @@ TYPED_TEST(HipcubDeviceSegmentedReduceOp, Reduce) init, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::params::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::params::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipFree(d_temporary_storage)); @@ -232,16 +227,12 @@ TYPED_TEST(HipcubDeviceSegmentedReduceOp, Reduce) test_utils::assert_near(aggregates_output, aggregates_expected, precision)); if(TestFixture::params::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); } } if(TestFixture::params::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } template get_discontinuity_probabilities() @@ -770,11 +743,9 @@ TYPED_TEST(HipcubDeviceSelectTests, Unique) test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // Run HIP_CHECK(hipcub::DeviceSelect::Unique(d_temp_storage, @@ -785,11 +756,8 @@ TYPED_TEST(HipcubDeviceSelectTests, Unique) input.size(), stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipDeviceSynchronize()); @@ -813,9 +781,7 @@ TYPED_TEST(HipcubDeviceSelectTests, Unique) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -826,9 +792,7 @@ TYPED_TEST(HipcubDeviceSelectTests, Unique) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TEST(HipcubDeviceSelectTests, UniqueDiscardOutputIterator) @@ -1188,11 +1152,9 @@ TYPED_TEST(HipcubDeviceUniqueByKeyTests, UniqueByKey) HIP_CHECK( test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); // run HIP_CHECK(hipcub::DeviceSelect::UniqueByKey(d_temp_storage, @@ -1206,11 +1168,8 @@ TYPED_TEST(HipcubDeviceUniqueByKeyTests, UniqueByKey) equality_op, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); // Check if number of selected value is as expected selected_count_type selected_count_output = 0; @@ -1242,9 +1201,7 @@ TYPED_TEST(HipcubDeviceUniqueByKeyTests, UniqueByKey) test_utils::assert_eq(output_values, expected_values, expected_values.size())); if(TestFixture::use_graphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - } + gHelper.cleanupGraphHelper(); HIP_CHECK(hipFree(d_keys_input)); HIP_CHECK(hipFree(d_values_input)); @@ -1257,9 +1214,7 @@ TYPED_TEST(HipcubDeviceUniqueByKeyTests, UniqueByKey) } if(TestFixture::use_graphs) - { HIP_CHECK(hipStreamDestroy(stream)); - } } TEST(HipcubDeviceUniqueByKeyTests, LargeIndicesUniqueByKey) diff --git a/test/hipcub/test_hipcub_device_spmv.cpp b/test/hipcub/test_hipcub_device_spmv.cpp index 5e9b157b..af7956a0 100644 --- a/test/hipcub/test_hipcub_device_spmv.cpp +++ b/test/hipcub/test_hipcub_device_spmv.cpp @@ -217,11 +217,9 @@ TYPED_TEST(HipcubDeviceSpmvTests, Spmv) HIP_CHECK(g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - if(TestFixture::use_graphs) - { - graph = test_utils::createGraphHelper(stream); - } + test_utils::GraphHelper gHelper; + if (TestFixture::use_graphs) + gHelper.startStreamCapture(stream); HIP_CHECK(hipcub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, @@ -235,11 +233,8 @@ TYPED_TEST(HipcubDeviceSpmvTests, Spmv) params.num_nonzeros, stream)); - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - { - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - } + if (TestFixture::use_graphs) + gHelper.createAndLaunchGraph(stream); HIP_CHECK(hipMemcpy(vector_y_in, params.d_vector_y, sizeof(T) * params.num_rows, hipMemcpyDeviceToHost)); @@ -257,7 +252,7 @@ TYPED_TEST(HipcubDeviceSpmvTests, Spmv) if(TestFixture::use_graphs) { - test_utils::cleanupGraphHelper(graph, graph_instance); + gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); } } diff --git a/test/hipcub/test_utils_hipgraphs.hpp b/test/hipcub/test_utils_hipgraphs.hpp index 2c416f64..f3e7e193 100644 --- a/test/hipcub/test_utils_hipgraphs.hpp +++ b/test/hipcub/test_utils_hipgraphs.hpp @@ -40,53 +40,64 @@ // Note: graphs will not work on the default stream. namespace test_utils { - -inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture = true) -{ - // Create a new graph - hipGraph_t graph; - HIP_CHECK(hipGraphCreate(&graph, 0)); - - // Optionally begin stream capture - if(beginCapture) - { - HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); - } - - return graph; -} - -inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance) -{ - HIP_CHECK(hipGraphDestroy(graph)); - HIP_CHECK(hipGraphExecDestroy(instance)); -} - -inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, - hipStream_t& stream, - const bool launchGraph = false, - const bool sync = false) -{ - // End the capture - HIP_CHECK(hipStreamEndCapture(stream, &graph)); - - // Instantiate the graph - hipGraphExec_t instance; - HIP_CHECK(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0)); - - // Optionally launch the graph - if(launchGraph) - HIP_CHECK(hipGraphLaunch(instance, stream)); - - // Optionally synchronize the stream when we're done - if(sync) - { - HIP_CHECK(hipStreamSynchronize(stream)); - } - - return instance; -} - + class GraphHelper{ + private: + hipGraph_t graph; + hipGraphExec_t graph_instance; + public: + + inline void startStreamCapture(hipStream_t & stream) + { + HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + } + + inline void endStreamCapture(hipStream_t & stream) + { + HIP_CHECK(hipStreamEndCapture(stream, &graph)); + } + + inline void createAndLaunchGraph(hipStream_t & stream, const bool launchGraph=true, const bool sync=true) + { + // End current capture + endStreamCapture(stream); + + // Create the graph instance + HIP_CHECK(hipGraphInstantiate(&graph_instance, graph, nullptr, nullptr, 0)); + + // Optionally launch the graph + if (launchGraph) + HIP_CHECK(hipGraphLaunch(graph_instance, stream)); + + // Optionally synchronize the stream when we're done + if (sync) + HIP_CHECK(hipStreamSynchronize(stream)); + } + + inline void cleanupGraphHelper() + { + HIP_CHECK(hipGraphDestroy(this->graph)); + HIP_CHECK(hipGraphExecDestroy(this->graph_instance)); + } + + inline void resetGraphHelper(hipStream_t& stream, const bool beginCapture=true) + { + // Destroy the old graph and instance + cleanupGraphHelper(); + + // Re-start capture + if(beginCapture) + startStreamCapture(stream); + } + + inline void launchGraphHelper(hipStream_t& stream,const bool sync=false) + { + HIP_CHECK(hipGraphLaunch(this->graph_instance, stream)); + + // Optionally sync after the launch + if (sync) + HIP_CHECK(hipStreamSynchronize(stream)); + } + }; } // end namespace test_utils #undef HIP_CHECK