Skip to content

Commit

Permalink
#1683: Updated memory collection and report of dram and l1 using new …
Browse files Browse the repository at this point in the history
…metal APIs.
  • Loading branch information
tapspatel committed Jan 7, 2025
1 parent 028f711 commit 7757540
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 221 deletions.
6 changes: 6 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ void deallocateBuffers(Device device);

void dumpMemoryReport(Device device);

std::unordered_map<int, tt::runtime::detail::MemoryView>
getDramMemoryView(Device device);

std::unordered_map<int, tt::runtime::detail::MemoryView>
getL1MemoryView(Device device);

void wait(Event event);

void wait(Tensor tensor);
Expand Down
6 changes: 6 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ void deallocateBuffers(Device device);

void dumpMemoryReport(Device device);

std::unordered_map<int, tt::runtime::detail::MemoryView>
getDramMemoryView(Device device);

std::unordered_map<int, tt::runtime::detail::MemoryView>
getL1MemoryView(Device device);

void wait(Event event);

void wait(Tensor tensor);
Expand Down
3 changes: 3 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

#include "tt/runtime/types.h"
Expand All @@ -20,6 +21,8 @@ std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();
namespace detail {
void deallocateBuffers(Device device);
void dumpMemoryReport(Device device);
std::unordered_map<int, MemoryView> getDramMemoryView(Device device);
std::unordered_map<int, MemoryView> getL1MemoryView(Device device);
} // namespace detail

DeviceRuntime getCurrentRuntime();
Expand Down
12 changes: 12 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ struct RuntimeCheckedObjectImpl {
}
};

struct MemoryView {
std::uint64_t num_banks;
size_t bytes_allocatable_per_bank;
size_t bytes_allocated_per_bank;
size_t bytes_free_per_bank;
size_t total_bytes_allocatable; // bytes_allocatable_per_bank * num_banks
size_t total_bytes_allocated; // bytes_allocated_per_bank * num_banks
size_t total_bytes_free; // bytes_free_per_bank * num_banks
size_t largest_contiguous_bytes_free_per_bank;
std::vector<std::unordered_map<std::string, std::string>> block_table;
};

} // namespace detail

struct TensorDesc {
Expand Down
35 changes: 35 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,41 @@ void dumpMemoryReport(Device device) {

LOG_FATAL("runtime is not enabled");
}

std::unordered_map<int, tt::runtime::detail::MemoryView>
getDramMemoryView(Device device) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getDramMemoryView(device);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getDramMemoryView(device);
}
#endif

LOG_FATAL("runtime is not enabled");
}

std::unordered_map<int, tt::runtime::detail::MemoryView>
getL1MemoryView(Device device) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getL1MemoryView(device);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getL1MemoryView(device);
}
#endif

LOG_FATAL("runtime is not enabled");
}

} // namespace detail

DeviceRuntime getCurrentRuntime() {
Expand Down
48 changes: 48 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ static Tensor createNullTensor() {
return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal);
}

static tt::runtime::detail::MemoryView
createMemoryView(const tt::tt_metal::detail::MemoryView &memoryView) {
return tt::runtime::detail::MemoryView{
.num_banks = memoryView.num_banks,
.bytes_allocatable_per_bank = memoryView.bytes_allocatable_per_bank,
.bytes_allocated_per_bank = memoryView.bytes_allocated_per_bank,
.bytes_free_per_bank = memoryView.bytes_free_per_bank,
.total_bytes_allocatable = memoryView.total_bytes_allocatable,
.total_bytes_allocated = memoryView.total_bytes_allocated,
.total_bytes_free = memoryView.total_bytes_free,
.largest_contiguous_bytes_free_per_bank =
memoryView.largest_contiguous_bytes_free_per_bank,
.block_table = memoryView.blockTable};
}

Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride,
Expand Down Expand Up @@ -113,6 +128,39 @@ void dumpMemoryReport(Device deviceHandle) {
}
}

std::unordered_map<int, tt::runtime::detail::MemoryView>
getDramMemoryView(Device deviceHandle) {
std::unordered_map<int, tt::runtime::detail::MemoryView> memoryMap;

::tt::tt_metal::distributed::MeshDevice &meshDevice =
deviceHandle.as<::tt::tt_metal::distributed::MeshDevice>(
DeviceRuntime::TTMetal);

for (::tt::tt_metal::Device *device : meshDevice.get_devices()) {
auto dramMemoryView = ::tt::tt_metal::detail::GetDramMemoryView(device);
memoryMap[device->id()] = createMemoryView(dramMemoryView);
}

return memoryMap;
}

std::unordered_map<int, tt::runtime::detail::MemoryView>
getL1MemoryView(Device deviceHandle) {
std::unordered_map<int, tt::runtime::detail::MemoryView> memoryMap;

::tt::tt_metal::distributed::MeshDevice &meshDevice =
deviceHandle.as<::tt::tt_metal::distributed::MeshDevice>(
DeviceRuntime::TTMetal);

for (::tt::tt_metal::Device *device : meshDevice.get_devices()) {
auto l1MemoryView = ::tt::tt_metal::detail::GetL1MemoryView(device);
memoryMap[device->id()] = memoryMap[device->id()] =
createMemoryView(l1MemoryView);
}

return memoryMap;
}

void wait(Event event) {
Events events = event.as<Events>(DeviceRuntime::TTMetal);
for (auto e : events) {
Expand Down
45 changes: 45 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ static Tensor createNullTensor() {
return Tensor(nullptr, nullptr, DeviceRuntime::TTNN);
}

static tt::runtime::detail::MemoryView
createMemoryView(const tt::tt_metal::detail::MemoryView &memoryView) {
return tt::runtime::detail::MemoryView{
.num_banks = memoryView.num_banks,
.bytes_allocatable_per_bank = memoryView.bytes_allocatable_per_bank,
.bytes_allocated_per_bank = memoryView.bytes_allocated_per_bank,
.bytes_free_per_bank = memoryView.bytes_free_per_bank,
.total_bytes_allocatable = memoryView.total_bytes_allocatable,
.total_bytes_allocated = memoryView.total_bytes_allocated,
.total_bytes_free = memoryView.total_bytes_free,
.largest_contiguous_bytes_free_per_bank =
memoryView.largest_contiguous_bytes_free_per_bank,
.block_table = memoryView.blockTable};
}

static DeviceVariant getTargetDevice(::ttnn::MeshDevice &meshDevice) {
if (meshDevice.num_devices() == 1) {
return std::ref(*(meshDevice.get_device_index(0)));
Expand Down Expand Up @@ -222,6 +237,36 @@ void dumpMemoryReport(Device deviceHandle) {
}
}

std::unordered_map<int, tt::runtime::detail::MemoryView>
getDramMemoryView(Device deviceHandle) {
std::unordered_map<int, tt::runtime::detail::MemoryView> memoryMap;

::ttnn::MeshDevice &meshDevice =
deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);

for (::ttnn::Device *device : meshDevice.get_devices()) {
auto dramMemoryView = ::tt::tt_metal::detail::GetDramMemoryView(device);
memoryMap[device->id()] = createMemoryView(dramMemoryView);
}

return memoryMap;
}

std::unordered_map<int, tt::runtime::detail::MemoryView>
getL1MemoryView(Device deviceHandle) {
std::unordered_map<int, tt::runtime::detail::MemoryView> memoryMap;

::ttnn::MeshDevice &meshDevice =
deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);

for (::ttnn::Device *device : meshDevice.get_devices()) {
auto l1MemoryView = ::tt::tt_metal::detail::GetL1MemoryView(device);
memoryMap[device->id()] = createMemoryView(l1MemoryView);
}

return memoryMap;
}

void wait(Event event) {
// Nothing to do for ttnn runtime
LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN));
Expand Down
Loading

0 comments on commit 7757540

Please sign in to comment.