diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 7a95b4b0a3fb..9e35290fabd7 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -54,7 +54,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { this->workspace_size = workspace->shape[0]; } else { // Fallback to thrust TLS caching allocator if workspace is not provided. - thrust_pool_ = thrust::mr::tls_disjoint_pool( + thrust_pool_ = &thrust::mr::tls_disjoint_pool( thrust::mr::get_global_resource(), thrust::mr::get_global_resource()); } @@ -67,20 +67,20 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { << " bytes."; return result; } - return thrust_pool_.do_allocate(bytes, alignment).get(); + return thrust_pool_->do_allocate(bytes, alignment).get(); } void do_deallocate(void* p, size_t bytes, size_t alignment) override { if (workspace != nullptr) { // No-op } else { - thrust_pool_.do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment); + thrust_pool_->do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment); } } thrust::mr::disjoint_unsynchronized_pool_resource - thrust_pool_; + thrust::mr::new_delete_resource>* thrust_pool_ = + nullptr; void* workspace = nullptr; size_t workspace_size = 0;