Skip to content

Commit

Permalink
[Thrust] Use pointer to tls pool to prevent creating new pool (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored Apr 9, 2024
1 parent 0594994 commit a309b6b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
this->workspace_size = workspace->shape[0];
} else {
// Fallback to thrust TLS caching allocator if workspace is not provided.
thrust_pool_ = thrust::mr::tls_disjoint_pool(
thrust_pool_ = &thrust::mr::tls_disjoint_pool(
thrust::mr::get_global_resource<thrust::device_memory_resource>(),
thrust::mr::get_global_resource<thrust::mr::new_delete_resource>());
}
Expand All @@ -67,20 +67,20 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
<< " bytes.";
return result;
}
return thrust_pool_.do_allocate(bytes, alignment).get();
return thrust_pool_->do_allocate(bytes, alignment).get();
}

void do_deallocate(void* p, size_t bytes, size_t alignment) override {
if (workspace != nullptr) {
// No-op
} else {
thrust_pool_.do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment);
thrust_pool_->do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment);
}
}

thrust::mr::disjoint_unsynchronized_pool_resource<thrust::device_memory_resource,
thrust::mr::new_delete_resource>
thrust_pool_;
thrust::mr::new_delete_resource>* thrust_pool_ =
nullptr;

void* workspace = nullptr;
size_t workspace_size = 0;
Expand Down

0 comments on commit a309b6b

Please sign in to comment.