Skip to content

Commit

Permalink
#0: Add WriteShard and ReadShard MeshBuffer APIs and resolve MeshBuff…
Browse files Browse the repository at this point in the history
…er dealloc issues

  - Add tests for reading and writing shards with Interleaved and Sharded configs
  - Add test for deallocation, verifying addresses
  • Loading branch information
tt-asaigal committed Jan 24, 2025
1 parent c203bf6 commit 7bf17b9
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 49 deletions.
155 changes: 148 additions & 7 deletions tests/tt_metal/distributed/test_mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,44 @@
#include <tt-metalium/mesh_device_view.hpp>

#include "tests/tt_metal/tt_metal/common/multi_device_fixture.hpp"
#include "tt_metal/distributed/mesh_buffer.hpp"
#include "tt_metal/distributed/distributed.hpp"

namespace tt::tt_metal::distributed::test {
namespace {

using MeshBufferTest = T3000MultiDeviceFixture;

struct DeviceLocalShardedBufferTestConfig {
Shape2D num_pages_per_core;
Shape2D num_cores;
Shape2D page_shape;
uint32_t element_size = 1;
TensorMemoryLayout mem_config = TensorMemoryLayout::HEIGHT_SHARDED;
ShardOrientation shard_orientation = ShardOrientation::ROW_MAJOR;

Shape2D tensor2d_shape() {
return {num_pages_per_core.height() * num_cores.height(), num_pages_per_core.width() * num_cores.width()};
}

uint32_t num_pages() { return tensor2d_shape().height() * tensor2d_shape().width(); }

std::array<uint32_t, 2> shard_shape() {
return {num_pages_per_core.height() * page_shape.height(), num_pages_per_core.width() * page_shape.width()};
}

CoreRangeSet shard_grid() {
return CoreRangeSet(std::set<CoreRange>(
{CoreRange(CoreCoord(0, 0), CoreCoord(this->num_cores.height() - 1, this->num_cores.width() - 1))}));
}

uint32_t page_size() { return page_shape.height() * page_shape.width() * element_size; }

ShardSpecBuffer shard_parameters() {
return ShardSpecBuffer(
this->shard_grid(), this->shard_shape(), this->shard_orientation, this->page_shape, this->tensor2d_shape());
}
};

TEST_F(MeshBufferTest, ConfigValidation) {
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
Expand Down Expand Up @@ -78,22 +109,24 @@ TEST_F(MeshBufferTest, ReplicatedBufferInitialization) {
}

TEST_F(MeshBufferTest, Deallocation) {
// Verify that a buffer is deallocated on the MeshDevice when it goes
// out of scope on host. Create a buffer with a certain config in limited
// scope. Record its address. Create another buffer with the same config
// outside the scope. Verify that addresses match.
// NOTE(review): this body looks like it contains both the allocator-tracking checks and the
// address-comparison checks; confirm both are meant to be present together.
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
.buffer_type = BufferType::DRAM,
.buffer_layout = TensorMemoryLayout::INTERLEAVED,
.bottom_up = false};

// .size = 16 << 10 == 16384 (presumably bytes -- TODO confirm the unit of ReplicatedBufferConfig::size).
const ReplicatedBufferConfig buffer_config{.size = 16 << 10};
// Declared outside the inner scope so they stay usable after the MeshBuffer handle
// created inside that scope has been released.
std::shared_ptr<Buffer> buffer;
Allocator* allocator = nullptr;
// NOTE(review): address() is stored in a uint32_t here; confirm the address type is not wider.
uint32_t expected_address = 0;
{
auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
// Grab the device buffer at mesh coordinate (0, 0) and the allocator it reports.
buffer = replicated_buffer->get_device_buffer(Coordinate{0, 0});
allocator = buffer->allocator();
EXPECT_TRUE(allocator->allocated_buffers.contains(buffer.get()));
// Recorded for comparison with the second allocation made after this scope ends.
expected_address = replicated_buffer->address();
}
// NOTE(review): `buffer` still holds a shared_ptr to the (0, 0) device buffer at this point; this
// check only asserts that the allocator no longer tracks that pointer -- confirm that is intended.
EXPECT_FALSE(allocator->allocated_buffers.contains(buffer.get()));
// Same configs as the scoped buffer above; the test expects the new allocation to land
// at the address recorded inside the scope.
auto replicated_buffer = MeshBuffer::create(buffer_config, device_local_config, mesh_device_.get());
EXPECT_EQ(replicated_buffer->address(), expected_address);
}

TEST_F(MeshBufferTest, GetDeviceBuffer) {
Expand All @@ -112,5 +145,113 @@ TEST_F(MeshBufferTest, GetDeviceBuffer) {
EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
}

// Round-trips randomly sized interleaved buffers through WriteShard / ReadShard on every
// device in the mesh, once per buffer type (L1 and DRAM), and checks that the data read
// back from each device matches what was written.
//
// The RNG seed comes from the TT_METAL_SEED environment variable (default 0) and the RNG is
// re-created per buffer type, so both buffer types see the same sequence of buffer sizes.
TEST_F(MeshBufferTest, InterleavedShardsReadWrite) {
    constexpr uint32_t NUM_ITERS = 100;
    uint32_t seed = tt::parse_env("TT_METAL_SEED", 0);
    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

    for (auto buffer_type : {BufferType::L1, BufferType::DRAM}) {
        // Use the loop's buffer type; hard-coding L1 here would leave DRAM untested.
        DeviceLocalBufferConfig per_device_buffer_config{
            .page_size = single_tile_size,
            .buffer_type = buffer_type,
            .buffer_layout = TensorMemoryLayout::INTERLEAVED,
            .bottom_up = false};

        std::uniform_int_distribution<int> gen_num_tiles(1, 1024);
        std::mt19937 rng(seed);
        // Unsigned counter: matches NUM_ITERS and the uint32_t element type it seeds via iota.
        for (uint32_t i = 0; i < NUM_ITERS; i++) {
            uint32_t num_random_tiles = gen_num_tiles(rng);
            ReplicatedBufferConfig global_buffer_config = {
                .size = num_random_tiles * single_tile_size,
            };

            std::shared_ptr<MeshBuffer> buf =
                MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

            // Source data: consecutive values starting at the iteration index, so each
            // iteration writes a different pattern.
            std::vector<uint32_t> src_vec(num_random_tiles * single_tile_size / sizeof(uint32_t), 0);
            std::iota(src_vec.begin(), src_vec.end(), i);

            // Write the same data to every device first ...
            for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
                for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
                    WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
                }
            }

            // ... then read each device back and compare against the source.
            for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
                for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
                    std::vector<uint32_t> dst_vec = {};
                    ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
                    EXPECT_EQ(dst_vec, src_vec);
                }
            }
        }
    }
}

class DeviceLocalMeshBufferShardingTest
: public MeshBufferTest,
public testing::WithParamInterface<
std::tuple<std::array<uint32_t, 2>, std::array<uint32_t, 2>, TensorMemoryLayout>> {};

// Builds an L1 buffer sharded across the compute grid according to the test parameters,
// writes the same data to every device in the mesh, then reads each device back and
// checks the contents match.
TEST_P(DeviceLocalMeshBufferShardingTest, ShardingTest) {
    auto [num_pages_per_core, page_shape, shard_strategy] = GetParam();
    CoreCoord core_grid_size = mesh_device_->compute_with_storage_grid_size();

    DeviceLocalShardedBufferTestConfig test_config{
        .num_pages_per_core = num_pages_per_core,
        .num_cores = {core_grid_size.x, core_grid_size.y},
        .page_shape = page_shape,
        .mem_config = shard_strategy};
    DeviceLocalBufferConfig per_device_buffer_config{
        .page_size = test_config.page_size(),
        .buffer_type = BufferType::L1,
        .buffer_layout = test_config.mem_config,
        .shard_parameters = test_config.shard_parameters(),
        .bottom_up = false};

    // Total buffer size: number of pages times the size of one page.
    uint32_t buf_size = test_config.num_pages() * test_config.page_size();
    ReplicatedBufferConfig global_buffer_config{
        .size = buf_size,
    };

    auto buf = MeshBuffer::create(global_buffer_config, per_device_buffer_config, mesh_device_.get());

    // Source data: 0, 1, 2, ... covering the whole buffer as uint32_t words.
    std::vector<uint32_t> src_vec(buf_size / sizeof(uint32_t), 0);
    std::iota(src_vec.begin(), src_vec.end(), 0);

    // Calls `visit` once per mesh coordinate, columns in the outer loop and rows in the inner loop.
    auto for_each_device = [&](auto&& visit) {
        for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
            for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
                visit(Coordinate(logical_y, logical_x));
            }
        }
    };

    // All writes are issued before any read.
    for_each_device([&](const Coordinate& coord) {
        WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, coord);
    });

    for_each_device([&](const Coordinate& coord) {
        std::vector<uint32_t> dst_vec = {};
        ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, coord);
        EXPECT_EQ(dst_vec, src_vec);
    });
}

// Instantiates DeviceLocalMeshBufferShardingTest over every combination of the three value
// lists below (5 x 5 x 3 tuples). The tuple order matches the structured binding in ShardingTest:
// (num_pages_per_core, page_shape, shard_strategy).
INSTANTIATE_TEST_SUITE_P(
DeviceLocalMeshBufferShardingTests,
DeviceLocalMeshBufferShardingTest,
::testing::Combine(
// num_pages_per_core
::testing::Values(
std::array<uint32_t, 2>{1, 1},
std::array<uint32_t, 2>{3, 137},
std::array<uint32_t, 2>{67, 4},
std::array<uint32_t, 2>{7, 11},
std::array<uint32_t, 2>{2, 2}),
// page_shape
::testing::Values(
std::array<uint32_t, 2>{1, 1024},
std::array<uint32_t, 2>{1, 2048},
std::array<uint32_t, 2>{1, 4},
std::array<uint32_t, 2>{32, 32},
std::array<uint32_t, 2>{1, 120}),
// shard_strategy
::testing::Values(
TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED)));
} // namespace
} // namespace tt::tt_metal::distributed::test
22 changes: 22 additions & 0 deletions tt_metal/distributed/distributed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ void AddProgramToMeshWorkload(MeshWorkload& mesh_workload, Program& program, con

void EnqueueMeshWorkload(MeshCommandQueue& mesh_cq, MeshWorkload& mesh_workload, bool blocking);

// Enqueues a write of the host vector `src` to the shard of `mesh_buffer` located at mesh
// coordinate `coord`, using the given mesh command queue.
//
// mesh_cq     - command queue the write is enqueued on.
// mesh_buffer - destination MeshBuffer.
// src         - host data; only its data() pointer is forwarded, no size check is done here.
// coord       - mesh coordinate of the target shard.
// blocking    - forwarded to enqueue_write_shard; defaults to false.
//
// NOTE(review): with blocking == false, `src` presumably has to stay alive and unmodified
// until the write has completed -- confirm against enqueue_write_shard.
template <typename DType>
void WriteShard(
    MeshCommandQueue& mesh_cq,
    std::shared_ptr<MeshBuffer>& mesh_buffer,
    std::vector<DType>& src,
    const Coordinate& coord,
    bool blocking = false) {
    auto* const host_data = src.data();
    mesh_cq.enqueue_write_shard(mesh_buffer, host_data, coord, blocking);
}

// Enqueues a read of the shard of `mesh_buffer` located at mesh coordinate `coord` into the
// host vector `dst`, using the given mesh command queue.
//
// mesh_cq     - command queue the read is enqueued on.
// dst         - host destination; resized to page_size() * num_pages() of the device buffer
//               at `coord`, divided by sizeof(DType), before the read is enqueued.
// mesh_buffer - source MeshBuffer.
// coord       - mesh coordinate of the shard to read.
// blocking    - forwarded to enqueue_read_shard; defaults to true.
//
// NOTE(review): with blocking == false, `dst` presumably must not be inspected until the
// read has completed -- confirm against enqueue_read_shard.
template <typename DType>
void ReadShard(
    MeshCommandQueue& mesh_cq,
    std::vector<DType>& dst,
    std::shared_ptr<MeshBuffer>& mesh_buffer,
    const Coordinate& coord,
    bool blocking = true) {
    const auto device_shard = mesh_buffer->get_device_buffer(coord);
    const auto shard_size_bytes = device_shard->page_size() * device_shard->num_pages();
    dst.resize(shard_size_bytes / sizeof(DType));
    mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
}

void Finish(MeshCommandQueue& mesh_cq);

} // namespace distributed
Expand Down
79 changes: 46 additions & 33 deletions tt_metal/distributed/mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,37 +59,24 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
}},
mesh_buffer_config);

// Rely on the MeshDevice allocator to provide the address for the entire mesh buffer.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device,
/*address=*/address.value_or(0),
device_local_size,
device_local_config.page_size,
device_local_config.buffer_type,
device_local_config.buffer_layout,
device_local_config.shard_parameters,
device_local_config.bottom_up);
std::shared_ptr<MeshBuffer> mesh_buffer;
if (!address.has_value()) {
*address = tt::tt_metal::detail::AllocateBuffer(backing_buffer.get());
auto* backing_buffer_ptr = backing_buffer.get();
// Rely on the MeshDevice allocator to provide the address for the entire mesh buffer.
// The address provided to the backing buffer is used as the address for the MeshBuffer object.
std::shared_ptr<Buffer> backing_buffer = Buffer::create(
mesh_device,
device_local_size,
device_local_config.page_size,
device_local_config.buffer_type,
device_local_config.buffer_layout,
device_local_config.shard_parameters,
device_local_config.bottom_up);

mesh_buffer = std::shared_ptr<MeshBuffer>(
new MeshBuffer(
mesh_buffer_config,
device_local_config,
*address,
device_local_size,
mesh_device,
std::move(backing_buffer)),
[backing_buffer_ptr](MeshBuffer*) { tt::tt_metal::detail::DeallocateBuffer(backing_buffer_ptr); });
new MeshBuffer(mesh_buffer_config, device_local_config, device_local_size, mesh_device, backing_buffer));
} else {
mesh_buffer = std::shared_ptr<MeshBuffer>(new MeshBuffer(
mesh_buffer_config,
device_local_config,
*address,
device_local_size,
mesh_device,
std::move(backing_buffer)));
mesh_buffer = std::shared_ptr<MeshBuffer>(
new MeshBuffer(mesh_buffer_config, device_local_config, address.value(), device_local_size, mesh_device));
}

mesh_buffer->allocate();
Expand All @@ -98,12 +85,19 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
}

void MeshBuffer::allocate() {
if (backing_buffer_) {
TT_FATAL(
!address_, "The address for a MeshBuffer should not explicitly be initialized when it is being allocated");
address_ = backing_buffer_->address();
} else {
TT_FATAL(address_, "A MeshBuffer should be provided a valid address if its not being allocated");
}
buffers_ = std::vector<std::vector<std::shared_ptr<Buffer>>>(
mesh_device_->num_rows(), std::vector<std::shared_ptr<Buffer>>(mesh_device_->num_cols()));

auto allocate_device_buffer_at_address = [this](const Coordinate& coord) {
std::shared_ptr<Buffer> buffer = Buffer::create(
mesh_device_,
mesh_device_->get_device(coord.row, coord.col),
address_,
device_local_size_,
device_local_config_.page_size,
Expand All @@ -116,11 +110,7 @@ void MeshBuffer::allocate() {

for (int row = 0; row < mesh_device_->num_rows(); row++) {
for (int col = 0; col < mesh_device_->num_cols(); col++) {
if (row == 0 and col == 0) {
buffers_[row][col] = backing_buffer_;
} else {
buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
}
buffers_[row][col] = allocate_device_buffer_at_address(Coordinate{row, col});
}
}
}
Expand Down Expand Up @@ -155,4 +145,27 @@ const ShardedBufferConfig& MeshBuffer::global_shard_spec() const {
return std::get<ShardedBufferConfig>(config_);
}

uint32_t MeshBuffer::datum_size_bytes() const {
// Limitation for now.
TT_FATAL(
this->global_layout() == MeshBufferLayout::SHARDED,
"Can only query datum size for buffers sharded across the Mesh");
return this->global_shard_spec().compute_datum_size_bytes();
}

// Returns the shard shape from the sharded config, with any dimension that is 0 replaced by
// the corresponding dimension of the global buffer shape.
// Only supported when the buffer's global layout is SHARDED; any other layout trips the TT_FATAL below.
Shape2D MeshBuffer::physical_shard_shape() const {
    TT_FATAL(
        this->global_layout() == MeshBufferLayout::SHARDED,
        "Can only query physical shard shape for buffers sharded across the Mesh");
    auto sharded_config = std::get<ShardedBufferConfig>(config_);

    // Resolve each dimension independently: keep the configured value unless it is 0.
    auto shard_height = sharded_config.shard_shape.height();
    if (shard_height == 0) {
        shard_height = sharded_config.global_buffer_shape.height();
    }
    auto shard_width = sharded_config.shard_shape.width();
    if (shard_width == 0) {
        shard_width = sharded_config.global_buffer_shape.width();
    }
    return {shard_height, shard_width};
}

} // namespace tt::tt_metal::distributed
Loading

0 comments on commit 7bf17b9

Please sign in to comment.