diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 207a86a71c92..ef4b9b020f95 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -56,6 +56,8 @@ class VulkanThreadEntry { // the instance and device get destroyed. // The destruction need to be manually called // to ensure the destruction order. + + pool.reset(); streams_.clear(); for (const auto& kv : staging_buffers_) { if (!kv.second) { @@ -75,7 +77,7 @@ class VulkanThreadEntry { } TVMContext ctx; - WorkspacePool pool; + std::unique_ptr pool; VulkanStream* Stream(size_t device_id); VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); @@ -331,11 +333,11 @@ class VulkanDeviceAPI final : public DeviceAPI { } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { - return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); + return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(ctx, size); } void FreeWorkspace(TVMContext ctx, void* data) final { - VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); + VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data); } static const std::shared_ptr& Global() { @@ -999,7 +1001,8 @@ VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size } VulkanThreadEntry::VulkanThreadEntry() - : pool(static_cast(kDLVulkan), VulkanDeviceAPI::Global()) { + : pool(std::make_unique(static_cast(kDLVulkan), + VulkanDeviceAPI::Global())) { ctx.device_id = 0; ctx.device_type = static_cast(kDLVulkan); }