From a309b6b857e9abc6849193cc7fa80c015fee7969 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 8 Apr 2024 17:29:35 -0700 Subject: [PATCH] [Thrust] Use pointer to tls pool to prevent creating new pool (#16856) --- src/runtime/contrib/thrust/thrust.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 7a95b4b0a3fb..9e35290fabd7 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -54,7 +54,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { this->workspace_size = workspace->shape[0]; } else { // Fallback to thrust TLS caching allocator if workspace is not provided. - thrust_pool_ = thrust::mr::tls_disjoint_pool( + thrust_pool_ = &thrust::mr::tls_disjoint_pool( thrust::mr::get_global_resource(), thrust::mr::get_global_resource()); } @@ -67,20 +67,20 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { << " bytes."; return result; } - return thrust_pool_.do_allocate(bytes, alignment).get(); + return thrust_pool_->do_allocate(bytes, alignment).get(); } void do_deallocate(void* p, size_t bytes, size_t alignment) override { if (workspace != nullptr) { // No-op } else { - thrust_pool_.do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment); + thrust_pool_->do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment); } } thrust::mr::disjoint_unsynchronized_pool_resource - thrust_pool_; + thrust::mr::new_delete_resource>* thrust_pool_ = + nullptr; void* workspace = nullptr; size_t workspace_size = 0;