Add support for custom NDArray memory management
Credit to @icemelon9 and @wweic
jroesch committed Apr 30, 2019
1 parent d850073 commit c8ea99b
Showing 4 changed files with 231 additions and 9 deletions.
50 changes: 50 additions & 0 deletions src/runtime/memory_manager.cc
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/memory_manager.cc
* \brief Allocate and manage memory for the runtime.
*/
#include <utility>
#include "memory_manager.h"
#include "naive_allocator.h"
#include "pooled_allocator.h"

namespace tvm {
namespace runtime {

MemoryManager* MemoryManager::Global() {
static MemoryManager memory_manager;
return &memory_manager;
}

Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
std::lock_guard<std::mutex> lock(mu_);
if (allocators_.find(ctx) == allocators_.end()) {
// LOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "("
// << ctx.device_id << ")";
std::unique_ptr<Allocator> alloc(new NaiveAllocator(ctx));
allocators_.emplace(ctx, std::move(alloc));
}
return allocators_.at(ctx).get();
}

} // namespace runtime
} // namespace tvm
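
The companion header memory_manager.h is included above but is not one of the four files in this diff. The sketch below reconstructs the Buffer, Allocator, and MemoryManager declarations that the call sites in this file (and in pooled_allocator.h further down) appear to rely on; the ContextHash and ContextEqual helpers and all exact signatures are assumptions, not the committed header.

// Sketch of memory_manager.h, inferred from its call sites; the real header may differ.
#include <tvm/runtime/c_runtime_api.h>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>

namespace tvm {
namespace runtime {

// A raw chunk of device memory handed out by an Allocator.
struct Buffer {
  void* data{nullptr};  // pointer to the allocated data space
  size_t size{0};       // size of the allocation in bytes
  TVMContext ctx;       // device on which the buffer lives
};

// Interface every allocation strategy implements.
class Allocator {
 public:
  explicit Allocator(TVMContext ctx) : ctx_(ctx) {}
  virtual ~Allocator() = default;
  virtual Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) = 0;
  virtual void Free(const Buffer& buffer) = 0;
  virtual size_t UsedMemory() = 0;

 protected:
  TVMContext ctx_;
};

// Assumed helpers so a TVMContext can key an unordered_map.
struct ContextHash {
  size_t operator()(const TVMContext& ctx) const {
    return std::hash<int>()(static_cast<int>(ctx.device_type) * 131 + ctx.device_id);
  }
};

struct ContextEqual {
  bool operator()(const TVMContext& a, const TVMContext& b) const {
    return a.device_type == b.device_type && a.device_id == b.device_id;
  }
};

// Process-wide registry that maps each device context to its allocator.
class MemoryManager {
 public:
  static MemoryManager* Global();
  Allocator* GetAllocator(TVMContext ctx);

 private:
  std::mutex mu_;
  std::unordered_map<TVMContext, std::unique_ptr<Allocator>, ContextHash, ContextEqual>
      allocators_;
};

}  // namespace runtime
}  // namespace tvm
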
50 changes: 50 additions & 0 deletions src/runtime/naive_allocator.h
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
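
memory_manager.cc above instantiates a NaiveAllocator per context. As a point of reference, the sketch below shows a minimal pass-through allocator that satisfies the Allocator interface by forwarding each request directly to DeviceAPI without any pooling; the member names and exact contents are assumptions about the committed naive_allocator.h, not a copy of it.

// Sketch only: a pass-through allocator matching the Allocator interface used above.
#ifndef TVM_RUNTIME_NAIVE_ALLOCATOR_H_
#define TVM_RUNTIME_NAIVE_ALLOCATOR_H_

#include <tvm/runtime/device_api.h>
#include <atomic>

#include "memory_manager.h"

namespace tvm {
namespace runtime {

class NaiveAllocator final : public Allocator {
 public:
  explicit NaiveAllocator(TVMContext ctx) : Allocator(ctx), used_memory_(0) {}

  Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override {
    // Allocate exactly the requested bytes directly from the device API.
    Buffer buf;
    buf.ctx = ctx_;
    buf.size = nbytes;
    buf.data = DeviceAPI::Get(ctx_)->AllocDataSpace(ctx_, nbytes, alignment, type_hint);
    used_memory_.fetch_add(nbytes, std::memory_order_relaxed);
    return buf;
  }

  void Free(const Buffer& buffer) override {
    // Return the memory to the device immediately; nothing is cached.
    DeviceAPI::Get(buffer.ctx)->FreeDataSpace(buffer.ctx, buffer.data);
    used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed);
  }

  size_t UsedMemory() override { return used_memory_.load(std::memory_order_relaxed); }

 private:
  std::atomic<size_t> used_memory_;
};

}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_NAIVE_ALLOCATOR_H_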

39 changes: 30 additions & 9 deletions src/runtime/ndarray.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -26,6 +26,8 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>

#include "memory_manager.h"
#include "runtime_base.h"

// deleter for arrays used by DLPack exporter
@@ -76,15 +78,27 @@ struct NDArray::Internal {
}
delete ptr;
}

static void BufferDeleter(NDArray::Container* ptr) {
CHECK(ptr->buffer_ != nullptr);
MemoryManager::Global()->GetAllocator(ptr->buffer_->ctx)->
Free(*(ptr->buffer_));
delete ptr->buffer_;
delete ptr;
}
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
static NDArray Create(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
DLContext ctx, bool with_allocator = false) {
VerifyDataType(dtype);
// critical zone
NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter;
if (with_allocator) {
data->deleter = BufferDeleter;
} else {
data->deleter = DefaultDeleter;
}
NDArray ret(data);
ret.data_ = data;
// RAII now in effect
@@ -142,14 +156,21 @@ DLManagedTensor* NDArray::ToDLPack() const {

NDArray NDArray::Empty(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx);
DLContext ctx,
Allocator* allocator) {
NDArray ret = Internal::Create(shape, dtype, ctx, (allocator != nullptr));
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
ret.data_->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
if (allocator == nullptr) {
ret.data_->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
} else {
ret.data_->buffer_ = new Buffer;
*ret.data_->buffer_ = allocator->Alloc(size, alignment, ret->dtype);
ret.data_->dl_tensor.data = ret.data_->buffer_->data;
}
return ret;
}
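
A hypothetical end-to-end use of the new parameter (the function name and includes below are illustrative, not part of this commit): fetch the per-context allocator from MemoryManager and hand it to NDArray::Empty, so the data space is owned by that allocator and released through BufferDeleter.

// Hypothetical example: allocate an NDArray whose storage is owned by the
// per-context allocator instead of a raw DeviceAPI::AllocDataSpace call.
#include <tvm/runtime/ndarray.h>
#include "memory_manager.h"

void ExampleAllocatorBackedNDArray() {
  TVMContext ctx;
  ctx.device_type = kDLCPU;
  ctx.device_id = 0;

  DLDataType dtype;
  dtype.code = kDLFloat;
  dtype.bits = 32;
  dtype.lanes = 1;

  // A non-null allocator routes the data space through Alloc/Free above;
  // passing nullptr keeps the original DeviceAPI::AllocDataSpace path.
  tvm::runtime::Allocator* alloc =
      tvm::runtime::MemoryManager::Global()->GetAllocator(ctx);
  tvm::runtime::NDArray arr = tvm::runtime::NDArray::Empty({64, 64}, dtype, ctx, alloc);
  (void)arr;  // storage is freed via BufferDeleter when the last reference drops
}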

101 changes: 101 additions & 0 deletions src/runtime/pooled_allocator.h
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file runtime/pooled_allocator.h
*/
#ifndef TVM_RUNTIME_POOLED_ALLOCATOR_H_
#define TVM_RUNTIME_POOLED_ALLOCATOR_H_

#include <tvm/runtime/device_api.h>
#include <atomic>
#include <mutex>
#include <unordered_map>
#include <vector>

#include "memory_manager.h"

namespace tvm {
namespace runtime {

class PooledAllocator final : public Allocator {
public:
static constexpr size_t kDefaultPageSize = 4096;

explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
: Allocator(ctx), page_size_(page_size), used_memory_(0) {}

~PooledAllocator() { ReleaseAll(); }

Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override {
std::lock_guard<std::mutex> lock(mu_);
size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_;
auto&& it = memory_pool_.find(size);
if (it != memory_pool_.end() && !it->second.empty()) {
auto&& pool = it->second;
auto ret = pool.back();
pool.pop_back();
return ret;
}
Buffer buf;
buf.ctx = ctx_;
buf.size = size;
buf.data = DeviceAPI::Get(ctx_)->AllocDataSpace(ctx_, size, alignment, type_hint);
used_memory_.fetch_add(size, std::memory_order_relaxed);
LOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B";
return buf;
}

void Free(const Buffer& buffer) override {
std::lock_guard<std::mutex> lock(mu_);
if (memory_pool_.find(buffer.size) == memory_pool_.end()) {
memory_pool_.emplace(buffer.size, std::vector<Buffer>{});
}
memory_pool_.at(buffer.size).push_back(buffer);
LOG(INFO) << "reclaim buffer " << buffer.size;
}

size_t UsedMemory() override { return used_memory_.load(std::memory_order_relaxed); }

private:
void ReleaseAll() {
std::lock_guard<std::mutex> lock(mu_);
for (auto const& it : memory_pool_) {
auto const& pool = it.second;
for (auto const& buf : pool) {
DeviceAPI::Get(buf.ctx)->FreeDataSpace(buf.ctx, buf.data);
}
}
memory_pool_.clear();
used_memory_ = 0;
LOG(INFO) << "release all buffers";
}

private:
size_t page_size_;
std::atomic<size_t> used_memory_;
std::unordered_map<size_t, std::vector<Buffer>> memory_pool_;
std::mutex mu_;
};

} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_POOLED_ALLOCATOR_H_
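
A note on the pooling strategy above: Alloc rounds every request up to a whole number of pages (4096 B by default), so a 5000-byte request produces an 8192-byte buffer; Free does not hand memory back to the device but parks the buffer in the bucket keyed by that rounded size for later reuse, and the underlying device memory is only released through DeviceAPI::FreeDataSpace when the allocator is destroyed and ReleaseAll runs.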
