diff --git a/tests/mr/device/mr_multithreaded_tests.cpp b/tests/mr/device/mr_multithreaded_tests.cpp index 38c34d93f..6d6d8edc2 100644 --- a/tests/mr/device/mr_multithreaded_tests.cpp +++ b/tests/mr/device/mr_multithreaded_tests.cpp @@ -179,6 +179,7 @@ void allocate_loop(rmm::mr::device_memory_resource* mr, std::size_t num_allocations, std::list& allocations, std::mutex& mtx, + cudaEvent_t& event, rmm::cuda_stream_view stream) { constexpr std::size_t max_size{1_MiB}; @@ -191,6 +192,7 @@ void allocate_loop(rmm::mr::device_memory_resource* mr, void* ptr = mr->allocate(size, stream); { std::lock_guard lock(mtx); + RMM_CUDA_TRY(cudaEventRecord(event, stream.value())); allocations.emplace_back(ptr, size); } } @@ -200,12 +202,14 @@ void deallocate_loop(rmm::mr::device_memory_resource* mr, std::size_t num_allocations, std::list& allocations, std::mutex& mtx, + cudaEvent_t& event, rmm::cuda_stream_view stream) { for (std::size_t i = 0; i < num_allocations;) { std::lock_guard lock(mtx); if (allocations.empty()) { continue; } i++; + RMM_CUDA_TRY(cudaStreamWaitEvent(stream.value(), event)); allocation alloc = allocations.front(); allocations.pop_front(); mr->deallocate(alloc.ptr, alloc.size, stream); @@ -220,15 +224,30 @@ void test_allocate_free_different_threads(rmm::mr::device_memory_resource* mr, std::mutex mtx; std::list allocations; - - std::thread producer( - allocate_loop, mr, num_allocations, std::ref(allocations), std::ref(mtx), streamA); - - std::thread consumer( - deallocate_loop, mr, num_allocations, std::ref(allocations), std::ref(mtx), streamB); + cudaEvent_t event; + + RMM_CUDA_TRY(cudaEventCreate(&event)); + + std::thread producer(allocate_loop, + mr, + num_allocations, + std::ref(allocations), + std::ref(mtx), + std::ref(event), + streamA); + + std::thread consumer(deallocate_loop, + mr, + num_allocations, + std::ref(allocations), + std::ref(mtx), + std::ref(event), + streamB); producer.join(); consumer.join(); + + RMM_CUDA_TRY(cudaEventDestroy(event)); } TEST_P(mr_test_mt, AllocFreeDifferentThreadsDefaultStream)