Skip to content

Commit

Permalink
Merge pull request apache#45 from octoml/toggle-memory-profiling
Browse files Browse the repository at this point in the history
Toggle memory profiling properly
  • Loading branch information
masahi authored Feb 15, 2024
2 parents 30f54c0 + 4c8d237 commit 5abceb5
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
14 changes: 13 additions & 1 deletion include/tvm/runtime/memory/memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Allocator {
/*! \brief The amount of memory currently allocated.
* \return The amount of memory currently allocated.
*/
virtual size_t UsedMemory() const = 0;

protected:
virtual Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
Expand Down Expand Up @@ -126,9 +126,21 @@ class MemoryManager {

static size_t UsedMemory(Device dev);

static void StartProfiling();
static void StopProfiling();

private:
MemoryManager() {}

/*!
 * \brief Invoke `func` once for every registered allocator, passing the raw
 *        allocator pointer, its type tag, and the device it serves.
 * \note The manager mutex `mu_` is held for the entire traversal, so `func`
 *       must not call back into locking MemoryManager entry points.
 */
void ForEachAllocator(std::function<void(Allocator*, AllocatorType, Device)> func) {
  std::lock_guard<std::mutex> guard(mu_);
  for (const auto& per_device : allocators_) {
    const Device& device = per_device.first;
    for (const auto& entry : per_device.second) {
      func(entry.second.get(), entry.first, device);
    }
  }
}

protected:
std::mutex mu_;
std::unordered_map<Device, std::unordered_map<AllocatorType, std::unique_ptr<Allocator>>>
Expand Down
32 changes: 27 additions & 5 deletions src/runtime/memory/memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,26 @@ Allocator* MemoryManager::GetAllocator(Device dev, AllocatorType type) {

/*! \brief Release every buffer held by every registered allocator. */
void MemoryManager::Clear() {
  MemoryManager* m = MemoryManager::Global();
  // ForEachAllocator acquires the manager lock itself; taking mu_ here as the
  // pre-refactor code did would self-deadlock. The diff residue that left both
  // the old lock/loop and the new call interleaved is removed.
  m->ForEachAllocator(
      [](Allocator* allocator, AllocatorType alloc_type, Device dev) { allocator->Clear(); });
}

void MemoryManager::StartProfiling() {
MemoryManager* m = MemoryManager::Global();
m->ForEachAllocator([](Allocator* allocator, AllocatorType alloc_type, Device dev) {
if (auto* pooled_alloc = static_cast<PooledAllocator*>(allocator)) {
pooled_alloc->StartProfiling();
}
}
});
}

void MemoryManager::StopProfiling() {
MemoryManager* m = MemoryManager::Global();
m->ForEachAllocator([](Allocator* allocator, AllocatorType alloc_type, Device dev) {
if (auto* pooled_alloc = static_cast<PooledAllocator*>(allocator)) {
pooled_alloc->StopProfiling();
}
});
}

size_t MemoryManager::UsedMemory(Device dev) {
Expand Down Expand Up @@ -234,6 +248,14 @@ TVM_REGISTER_GLOBAL("vm.memory_manager.get_used_memory").set_body_typed([](Devic
return static_cast<int64_t>(MemoryManager::UsedMemory(dev));
});

// Expose MemoryManager::StartProfiling over the FFI; the static member
// function pointer is passed directly instead of a forwarding lambda.
TVM_REGISTER_GLOBAL("vm.memory_manager.start_profiling")
    .set_body_typed(&MemoryManager::StartProfiling);

// Expose MemoryManager::StopProfiling over the FFI; the static member
// function pointer is passed directly instead of a forwarding lambda.
TVM_REGISTER_GLOBAL("vm.memory_manager.stop_profiling")
    .set_body_typed(&MemoryManager::StopProfiling);

} // namespace memory
} // namespace runtime
} // namespace tvm
2 changes: 1 addition & 1 deletion src/runtime/memory/naive_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class NaiveAllocator final : public Allocator {
DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B";
}

// Current allocation footprint in bytes. A relaxed load suffices for a
// monitoring-only counter; only the const overload from the final revision
// is kept (the diff showed old and new signatures back-to-back).
size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); }

private:
std::atomic<size_t> used_memory_;
Expand Down
12 changes: 9 additions & 3 deletions src/runtime/memory/pooled_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,18 @@ class PooledAllocator final : public Allocator {

// Drop every cached buffer by delegating to ReleaseAll().
void Clear() override { ReleaseAll(); }

size_t UsedMemory() override {
// HACK to disable eager recycling during memory profiling
recycle_eager = true;
size_t UsedMemory() const override {
return used_memory_.load(std::memory_order_relaxed);
}

// Begin a profiling session: disable eager recycling while memory usage is
// being measured.
void StartProfiling() { recycle_eager = false; }

// End a profiling session: restore eager recycling of freed buffers.
void StopProfiling() { recycle_eager = true; }

private:
void ReleaseAll() {
std::lock_guard<std::recursive_mutex> lock(mu_);
Expand Down

0 comments on commit 5abceb5

Please sign in to comment.