Skip to content

Commit

Permalink
Fix hipGraph memory leaks (#445)
Browse files Browse the repository at this point in the history
Currently, the hipGraph tests overwrite the graph instance created by hipStreamStartCapture.
This change moves the hipGraph helper functions into a class that encapsulates the
graph instance, ensuring it doesn't get overwritten.
  • Loading branch information
umfranzw authored Dec 12, 2024
1 parent 5fc34f4 commit b92c92d
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 659 deletions.
23 changes: 7 additions & 16 deletions test/hipcub/test_hipcub_device_adjacent_difference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,9 @@ TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy)
HIP_CHECK(
test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));

hipGraph_t graph;
if(TestFixture::params::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::params::use_graphs)
gHelper.startStreamCapture(stream);

HIP_CHECK(dispatch_adjacent_difference(left_constant,
copy_constant,
Expand All @@ -231,11 +229,8 @@ TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy)
op,
stream));

hipGraphExec_t graph_instance;
if(TestFixture::params::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::params::use_graphs)
gHelper.createAndLaunchGraph(stream);

std::vector<output_type> output(size);
HIP_CHECK(hipMemcpy(output.data(),
Expand All @@ -253,16 +248,12 @@ TYPED_TEST(HipcubDeviceAdjacentDifference, SubtractLeftCopy)
ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(output, expected));

if(TestFixture::params::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();
}
}

if(TestFixture::params::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}

// Params for tests
Expand Down Expand Up @@ -461,4 +452,4 @@ TYPED_TEST(HipcubDeviceAdjacentDifferenceLargeTests, LargeIndicesAndOpOnce)
HIP_CHECK(hipFree(d_flags));
}
}
}
}
42 changes: 12 additions & 30 deletions test/hipcub/test_hipcub_device_for.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,15 @@ TYPED_TEST(HipcubDeviceForTests, ForEach)
std::vector<T> expected(input);
std::for_each(expected.begin(), expected.end(), plus<T>());

hipGraph_t graph;
if(TestFixture::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::use_graphs)
gHelper.startStreamCapture(stream);

// Run
HIP_CHECK(hipcub::ForEach(d_input, d_input + size, plus<T>(), stream));

hipGraphExec_t graph_instance;
if(TestFixture::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::use_graphs)
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipGetLastError());
HIP_CHECK(hipDeviceSynchronize());
Expand All @@ -144,18 +139,14 @@ TYPED_TEST(HipcubDeviceForTests, ForEach)
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));

if(TestFixture::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();

HIP_CHECK(hipFree(d_input));
}
}

if(TestFixture::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}

template<class T>
Expand Down Expand Up @@ -283,20 +274,15 @@ TYPED_TEST(HipcubDeviceForTests, ForEachN)
std::vector<T> expected(input);
std::for_each(expected.begin(), expected.begin() + n, plus<T>());

hipGraph_t graph;
if(TestFixture::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::use_graphs)
gHelper.startStreamCapture(stream);

// Run
HIP_CHECK(hipcub::ForEachN(d_input, n, plus<T>(), stream));

hipGraphExec_t graph_instance;
if(TestFixture::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::use_graphs)
gHelper.createAndLaunchGraph(stream);

HIP_CHECK(hipGetLastError());
HIP_CHECK(hipDeviceSynchronize());
Expand All @@ -312,18 +298,14 @@ TYPED_TEST(HipcubDeviceForTests, ForEachN)
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));

if(TestFixture::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();

HIP_CHECK(hipFree(d_input));
}
}

if(TestFixture::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}

template<class T>
Expand Down
84 changes: 24 additions & 60 deletions test/hipcub/test_hipcub_device_histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,9 @@ TYPED_TEST(HipcubDeviceHistogramEven, Even)
void * d_temporary_storage;
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));

hipGraph_t graph;
if(TestFixture::params::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::params::use_graphs)
gHelper.startStreamCapture(stream);

if(rows == 1)
{
Expand Down Expand Up @@ -306,11 +304,8 @@ TYPED_TEST(HipcubDeviceHistogramEven, Even)
stream));
}

hipGraphExec_t graph_instance;
if(TestFixture::params::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::params::use_graphs)
gHelper.createAndLaunchGraph(stream);

std::vector<counter_type> histogram(bins);
HIP_CHECK(
Expand All @@ -331,16 +326,12 @@ TYPED_TEST(HipcubDeviceHistogramEven, Even)
}

if(TestFixture::params::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();
}
}

if(TestFixture::params::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}

// Test HistogramEven overflow
Expand Down Expand Up @@ -625,11 +616,9 @@ TYPED_TEST(HipcubDeviceHistogramRange, Range)
void * d_temporary_storage;
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));

hipGraph_t graph;
if(TestFixture::params::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::params::use_graphs)
gHelper.startStreamCapture(stream);

if(rows == 1)
{
Expand All @@ -656,11 +645,8 @@ TYPED_TEST(HipcubDeviceHistogramRange, Range)
stream));
}

hipGraphExec_t graph_instance;
if(TestFixture::params::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::params::use_graphs)
gHelper.createAndLaunchGraph(stream);

std::vector<counter_type> histogram(bins);
HIP_CHECK(
Expand All @@ -682,15 +668,11 @@ TYPED_TEST(HipcubDeviceHistogramRange, Range)
}

if(TestFixture::params::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();
}
}
if(TestFixture::params::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}

template<class SampleType,
Expand Down Expand Up @@ -936,11 +918,9 @@ TYPED_TEST(HipcubDeviceHistogramMultiEven, MultiEven)
void * d_temporary_storage;
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));

hipGraph_t graph;
if(TestFixture::params::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::params::use_graphs)
gHelper.startStreamCapture(stream);

if(rows == 1)
{
Expand Down Expand Up @@ -971,11 +951,8 @@ TYPED_TEST(HipcubDeviceHistogramMultiEven, MultiEven)
stream)));
}

hipGraphExec_t graph_instance;
if(TestFixture::params::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::params::use_graphs)
gHelper.createAndLaunchGraph(stream);

std::vector<counter_type> histogram[active_channels];
for(unsigned int channel = 0; channel < active_channels; channel++)
Expand Down Expand Up @@ -1005,16 +982,12 @@ TYPED_TEST(HipcubDeviceHistogramMultiEven, MultiEven)
}

if(TestFixture::params::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();
}
}

if(TestFixture::params::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}

template<class SampleType,
Expand Down Expand Up @@ -1278,11 +1251,9 @@ TYPED_TEST(HipcubDeviceHistogramMultiRange, MultiRange)
void * d_temporary_storage;
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));

hipGraph_t graph;
if(TestFixture::params::use_graphs)
{
graph = test_utils::createGraphHelper(stream);
}
test_utils::GraphHelper gHelper;
if (TestFixture::params::use_graphs)
gHelper.startStreamCapture(stream);

if(rows == 1)
{
Expand Down Expand Up @@ -1311,11 +1282,8 @@ TYPED_TEST(HipcubDeviceHistogramMultiRange, MultiRange)
stream)));
}

hipGraphExec_t graph_instance;
if(TestFixture::params::use_graphs)
{
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true);
}
if (TestFixture::params::use_graphs)
gHelper.createAndLaunchGraph(stream);

std::vector<counter_type> histogram[active_channels];
for(unsigned int channel = 0; channel < active_channels; channel++)
Expand Down Expand Up @@ -1346,14 +1314,10 @@ TYPED_TEST(HipcubDeviceHistogramMultiRange, MultiRange)
}

if(TestFixture::params::use_graphs)
{
test_utils::cleanupGraphHelper(graph, graph_instance);
}
gHelper.cleanupGraphHelper();
}
}

if(TestFixture::params::use_graphs)
{
HIP_CHECK(hipStreamDestroy(stream));
}
}
Loading

0 comments on commit b92c92d

Please sign in to comment.