diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 63f3e4e5ffa2..10e8be35ef74 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -149,7 +149,11 @@ class SimpleObjAllocator : template class ArrayHandler { public: - using StorageType = typename std::aligned_union::type; + using StorageType = typename std::aligned_storage::type; + // for now only support elements that aligns with array header. + static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && + sizeof(ArrayType) % alignof(ElemType) == 0, + "element alignment constraint"); template static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { @@ -160,15 +164,15 @@ class SimpleObjAllocator : // In the case of an object pool, an allocator needs to create // a special chunk memory that hides reference to the allocator // and call allocator's release function in the deleter. - // NOTE2: Use inplace new to allocate // This is used to get rid of warning when deleting a virtual // class with non-virtual destructor. // We are fine here as we captured the right deleter during construction. // This is also the right way to get storage type for an object pool. - size_t factor = sizeof(ArrayType) / sizeof(ElemType); - num_elems = (num_elems + factor - 1) / factor; - StorageType* data = new StorageType[num_elems+1]; + size_t unit = sizeof(StorageType); + size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); + size_t num_storage_slots = (requested_size + unit - 1) / unit; + StorageType* data = new StorageType[num_storage_slots]; new (data) ArrayType(std::forward(args)...); return reinterpret_cast(data); }