#0: Add native 2D sharding and replication functionality to MeshBuffer
  - Add top level EnqueueWriteMeshBuffer and EnqueueReadMeshBuffer APIs
    to distributed.hpp
tt-asaigal committed Jan 25, 2025
1 parent b9a2fa7 commit 8b2c6cd
Showing 6 changed files with 411 additions and 16 deletions.
148 changes: 148 additions & 0 deletions tests/tt_metal/distributed/test_mesh_buffer.cpp
@@ -253,5 +253,153 @@ INSTANTIATE_TEST_SUITE_P(
        // shard_strategy
        ::testing::Values(
            TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED)));

TEST_F(MeshBufferTest, SweepShardAndConcat) {
    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

    DeviceLocalBufferConfig per_device_buffer_config{
        .page_size = single_tile_size,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = true};
    std::vector<Shape2D> global_buffer_shapes = {
        {64, 128}, {128, 128}, {32, 1024}, {1024, 32}, {512, 64}, {2048, 2048}};
    std::vector<Shape2D> shard_shapes = {{32, 32}, {32, 64}, {32, 128}, {128, 32}, {128, 32}, {512, 1024}};
    for (auto shard_orientation : {ShardOrientation::COL_MAJOR, ShardOrientation::ROW_MAJOR}) {
        for (int i = 0; i < global_buffer_shapes.size(); i++) {
            Shape2D global_buffer_shape = global_buffer_shapes[i];
            Shape2D shard_shape = shard_shapes[i];

            uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t);

            ShardedBufferConfig sharded_config{
                .global_size = global_buffer_size,
                .global_buffer_shape = global_buffer_shape,
                .shard_shape = shard_shape,
                .shard_orientation = shard_orientation,
            };

            auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get());
            std::vector<uint32_t> src_vec =
                std::vector<uint32_t>(global_buffer_shape.height() * global_buffer_shape.width(), 0);
            std::iota(src_vec.begin(), src_vec.end(), 0);
            EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), mesh_buffer, src_vec);
            std::vector<uint32_t> dst_vec = {};
            EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(), dst_vec, mesh_buffer);

            EXPECT_EQ(dst_vec, src_vec);
        }
    }
}

TEST_F(MeshBufferTest, RowMajorShardingAndReplication) {
    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

    DeviceLocalBufferConfig per_device_buffer_config{
        .page_size = single_tile_size,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = true};

    std::vector<Shape2D> global_buffer_shapes = {{64, 256}, {128, 128}, {256, 2048}, {32, 512}, {512, 1024}};

    for (int i = 0; i < global_buffer_shapes.size(); i++) {
        auto global_buffer_shape = global_buffer_shapes[i];
        // A shard height of 0 replicates the buffer along mesh rows; the width is sharded across mesh columns.
        Shape2D shard_shape = {0, global_buffer_shape.width() / mesh_device_->num_cols()};
        // Mesh-level sharding parameters for the read view used to verify correctness.
        Shape2D global_buffer_read_shape = {
            global_buffer_shape.height() * mesh_device_->num_rows(), global_buffer_shape.width()};
        Shape2D shard_read_shape = {
            global_buffer_shape.height(), global_buffer_shape.width() / mesh_device_->num_cols()};

        uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t);
        auto shard_orientation = ShardOrientation::ROW_MAJOR;

        ShardedBufferConfig sharded_config{
            .global_size = global_buffer_size,
            .global_buffer_shape = global_buffer_shape,
            .shard_shape = shard_shape,
            .shard_orientation = shard_orientation,
        };
        // Initialize the ShardedBufferConfig used to read back and verify the replicated data.
        ShardedBufferConfig sharded_read_view_config{
            .global_size = global_buffer_read_shape.height() * global_buffer_read_shape.width() * sizeof(uint32_t),
            .global_buffer_shape = global_buffer_read_shape,
            .shard_shape = shard_read_shape,
            .shard_orientation = shard_orientation};

        auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get());
        std::vector<uint32_t> src_vec =
            std::vector<uint32_t>(global_buffer_shape.height() * global_buffer_shape.width(), 0);
        std::iota(src_vec.begin(), src_vec.end(), 0);

        // Read view: a second MeshBuffer aliasing the first at the same address, interpreting the
        // replicated copies as a single larger sharded buffer.
        auto mesh_buffer_read_view = MeshBuffer::create(
            sharded_read_view_config, per_device_buffer_config, mesh_device_.get(), mesh_buffer->address());
        EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), mesh_buffer, src_vec);
        std::vector<uint32_t> dst_vec =
            std::vector<uint32_t>(global_buffer_read_shape.height() * global_buffer_read_shape.width(), 0);
        EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(), dst_vec, mesh_buffer_read_view);

        // The read view stacks the replicated copies vertically, so the data repeats every src_vec.size() elements.
        for (int i = 0; i < dst_vec.size(); i++) {
            EXPECT_EQ(dst_vec[i], i % (src_vec.size()));
        }
    }
}

TEST_F(MeshBufferTest, ColMajorShardingAndReplication) {
    uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

    DeviceLocalBufferConfig per_device_buffer_config{
        .page_size = single_tile_size,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = true};

    std::vector<Shape2D> global_buffer_shapes = {{256, 64}, {1024, 1024}, {128, 32}, {512, 64}, {2048, 256}};

    for (int i = 0; i < global_buffer_shapes.size(); i++) {
        auto global_buffer_shape = global_buffer_shapes[i];
        // A shard width of 0 replicates the buffer across mesh columns; the height is sharded across mesh rows.
        Shape2D shard_shape = {global_buffer_shape.height() / mesh_device_->num_rows(), 0};
        uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t);
        Shape2D global_buffer_read_shape = {
            global_buffer_shape.height(), global_buffer_shape.width() * mesh_device_->num_cols()};
        Shape2D shard_read_shape = {
            global_buffer_shape.height() / mesh_device_->num_rows(), global_buffer_shape.width()};

        ShardOrientation shard_orientation = ShardOrientation::COL_MAJOR;

        ShardedBufferConfig sharded_config{
            .global_size = global_buffer_size,
            .global_buffer_shape = global_buffer_shape,
            .shard_shape = shard_shape,
            .shard_orientation = shard_orientation,
        };

        ShardedBufferConfig sharded_read_view_config{
            .global_size = global_buffer_read_shape.height() * global_buffer_read_shape.width() * sizeof(uint32_t),
            .global_buffer_shape = global_buffer_read_shape,
            .shard_shape = shard_read_shape,
            .shard_orientation = ShardOrientation::ROW_MAJOR};

        auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get());
        std::vector<uint32_t> src_vec =
            std::vector<uint32_t>(global_buffer_shape.height() * global_buffer_shape.width(), 0);
        std::iota(src_vec.begin(), src_vec.end(), 0);

        auto mesh_buffer_read_view = MeshBuffer::create(
            sharded_read_view_config, per_device_buffer_config, mesh_device_.get(), mesh_buffer->address());

        EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), mesh_buffer, src_vec);
        std::vector<uint32_t> dst_vec =
            std::vector<uint32_t>(global_buffer_read_shape.height() * global_buffer_read_shape.width(), 0);
        EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(), dst_vec, mesh_buffer_read_view);
        // Each read-view row concatenates the replicated copies of the corresponding source row.
        for (int i = 0; i < dst_vec.size(); i++) {
            EXPECT_EQ(
                (i / global_buffer_read_shape.width()) * global_buffer_shape.width() + i % global_buffer_shape.width(),
                dst_vec[i]);
        }
    }
}

} // namespace
} // namespace tt::tt_metal::distributed::test
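
A note on the two replication tests above: correctness is verified through aliasing. The replicated buffer is written once; a second MeshBuffer is then created at the same device address with a read-view config that presents all replicas as a single, larger sharded buffer, so one read concatenates every copy and the expected pattern can be checked with plain index arithmetic.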
22 changes: 22 additions & 0 deletions tt_metal/distributed/distributed.hpp
@@ -46,6 +46,28 @@ void ReadShard(
    mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
}

template <typename DType>
void EnqueueWriteMeshBuffer(
    MeshCommandQueue& mesh_cq,
    std::shared_ptr<MeshBuffer>& mesh_buffer,
    std::vector<DType>& src,
    bool blocking = false) {
    mesh_cq.enqueue_write_mesh_buffer(mesh_buffer, src.data(), blocking);
}

template <typename DType>
void EnqueueReadMeshBuffer(
    MeshCommandQueue& mesh_cq,
    std::vector<DType>& dst,
    std::shared_ptr<MeshBuffer>& mesh_buffer,
    bool blocking = true) {
    TT_FATAL(
        mesh_buffer->global_layout() == MeshBufferLayout::SHARDED,
        "Can only read a Sharded MeshBuffer from a MeshDevice.");
    dst.resize(mesh_buffer->global_shard_spec().global_size / sizeof(DType));
    mesh_cq.enqueue_read_mesh_buffer(dst.data(), mesh_buffer, blocking);
}

void Finish(MeshCommandQueue& mesh_cq);

} // namespace distributed
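
Taken together, the new top-level APIs support a simple round trip. The following minimal sketch assumes an initialized MeshDevice named mesh_device and a SHARDED mesh_buffer created as in the tests above; note the defaults, where writes enqueue non-blocking while reads block and resize the destination to the global buffer size:

    // Minimal usage sketch; `mesh_device` and `mesh_buffer` are assumed to already exist.
    std::vector<uint32_t> src(64 * 128);  // one element per datum in a {64, 128} global buffer
    std::iota(src.begin(), src.end(), 0);
    EnqueueWriteMeshBuffer(mesh_device->mesh_command_queue(), mesh_buffer, src);  // non-blocking by default

    std::vector<uint32_t> dst;  // resized internally to global_size / sizeof(uint32_t)
    EnqueueReadMeshBuffer(mesh_device->mesh_command_queue(), dst, mesh_buffer);   // blocking by default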
40 changes: 26 additions & 14 deletions tt_metal/distributed/mesh_buffer.cpp
@@ -19,7 +19,7 @@ void validate_mesh_buffer_config(const MeshBufferConfig& config, const MeshDevic

     const auto& sharded_config = std::get<ShardedBufferConfig>(config);
     const auto [global_buffer_height, global_buffer_width] = sharded_config.global_buffer_shape;
-    const auto [shard_height, shard_width] = sharded_config.shard_shape;
+    const auto [shard_height, shard_width] = sharded_config.physical_shard_shape();
 
     TT_FATAL(
         (global_buffer_height % shard_height == 0) and (global_buffer_width % shard_width == 0),
@@ -32,13 +32,25 @@

     const auto num_shard_rows = global_buffer_height / shard_height;
     const auto num_shard_cols = global_buffer_width / shard_width;
-    const auto num_shards = num_shard_rows * num_shard_cols;
+    auto num_shards = num_shard_rows * num_shard_cols;
 
+    // The following check needs to account for shard orientation. The scaling factor for
+    // replication depends on which orientation we shard/replicate to when writing to device.
+    const auto& [height_replicated, width_replicated] = sharded_config.replicated_dims();
+    if (height_replicated and width_replicated) {
+        // Pure replication
+        num_shards *= mesh_device.num_cols() * mesh_device.num_rows();
+    } else if (height_replicated or width_replicated) {
+        // Replication along row or column dim.
+        num_shards *=
+            ((sharded_config.shard_orientation == ShardOrientation::ROW_MAJOR) * (mesh_device.num_rows()) +
+             (sharded_config.shard_orientation == ShardOrientation::COL_MAJOR) * (mesh_device.num_cols()));
+    }
     TT_FATAL(
         num_shards <= mesh_device.num_devices(),
-        "The number of shards must align with the mesh shape: number of shards: {}, mesh shape: ({}, {})",
+        "The sharded tensor does not fit on the Mesh. Num shards in buffer {}, Num Devices {}",
         num_shards,
-        mesh_device.num_rows(),
-        mesh_device.num_cols());
+        mesh_device.num_devices());
 }

} // namespace
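
To make the replication accounting concrete, here is a worked example under an assumed 2x4 mesh (num_rows() == 2, num_cols() == 4); the shapes are hypothetical, not taken from the commit:

    // global_buffer_shape = {64, 256}, shard_shape = {0, 64}, ShardOrientation::ROW_MAJOR
    // physical_shard_shape() -> {64, 64}  (a dimension of 0 falls back to the global extent)
    // num_shard_rows = 64 / 64 = 1, num_shard_cols = 256 / 64 = 4 -> num_shards = 4
    // height is replicated (shard height == 0), so ROW_MAJOR scales by num_rows():
    //     num_shards = 4 * 2 = 8
    // 8 shards <= 8 devices on a 2x4 mesh, so validation passes.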
@@ -54,7 +66,7 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
         tt::stl::overloaded{
             [](const ReplicatedBufferConfig& c) { return c.size; },
             [mesh_device](const ShardedBufferConfig& config) {
-                const auto [shard_height, shard_width] = config.shard_shape;
+                const auto [shard_height, shard_width] = config.physical_shard_shape();
                 return config.compute_datum_size_bytes() * shard_height * shard_width;
             }},
         mesh_buffer_config);
@@ -158,14 +170,14 @@ Shape2D MeshBuffer::physical_shard_shape() const {
         this->global_layout() == MeshBufferLayout::SHARDED,
         "Can only query physical shard shape for buffers sharded across the Mesh");
     auto sharded_config = std::get<ShardedBufferConfig>(config_);
-    Shape2D physical_shard_shape = sharded_config.shard_shape;
-    if (physical_shard_shape.height() == 0) {
-        physical_shard_shape = {sharded_config.global_buffer_shape.height(), physical_shard_shape.width()};
-    }
-    if (physical_shard_shape.width() == 0) {
-        physical_shard_shape = {physical_shard_shape.height(), sharded_config.global_buffer_shape.width()};
-    }
-    return physical_shard_shape;
+    return sharded_config.physical_shard_shape();
 }
 
+std::pair<bool, bool> MeshBuffer::replicated_dims() const {
+    TT_FATAL(
+        this->global_layout() == MeshBufferLayout::SHARDED,
+        "Can only query replicated dims for buffers sharded across the Mesh");
+    return this->global_shard_spec().replicated_dims();
+}
 
 } // namespace tt::tt_metal::distributed
14 changes: 14 additions & 0 deletions tt_metal/distributed/mesh_buffer.hpp
@@ -37,6 +37,8 @@ struct ReplicatedBufferConfig {

// Specifies sharded MeshBuffer.
struct ShardedBufferConfig {
    // Note: Only 2D sharding and replication is supported by the APIs exposed through this struct.
    // This interface will likely change over time depending on the status of native ND sharding.
    // Global buffer size. Each device will get a fraction of this size.
    DeviceAddr global_size = 0;

@@ -53,6 +55,14 @@
    uint32_t compute_datum_size_bytes() const {
        return global_size / (global_buffer_shape.height() * global_buffer_shape.width());
    }

    std::pair<bool, bool> replicated_dims() const { return {shard_shape.height() == 0, shard_shape.width() == 0}; }

    Shape2D physical_shard_shape() const {
        const auto [shard_height, shard_width] = shard_shape;
        const auto [global_height, global_width] = global_buffer_shape;
        return Shape2D(shard_height == 0 ? global_height : shard_height, shard_width == 0 ? global_width : shard_width);
    }
};

enum class MeshBufferLayout : uint8_t { REPLICATED, SHARDED };
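
As a brief sketch of the zero-dimension convention encoded by these helpers (values chosen for illustration, not from the commit):

    ShardedBufferConfig cfg{
        .global_size = 256 * 64 * sizeof(uint32_t),
        .global_buffer_shape = {256, 64},
        .shard_shape = {128, 0},  // width of 0 requests replication along that dimension
        .shard_orientation = ShardOrientation::COL_MAJOR};
    auto [height_replicated, width_replicated] = cfg.replicated_dims();  // {false, true}
    Shape2D physical = cfg.physical_shard_shape();                       // {128, 64}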
@@ -75,12 +85,16 @@ class MeshBuffer {

    MeshBufferLayout global_layout() const;
    const MeshBufferConfig& global_config() const { return config_; }
    // ND Sharding is not supported today. MeshBuffer only supports 2D sharding and
    // replication. Tensor sharding schemes that can be lowered to 2D configurations
    // are thus supported by the MeshCommandQueue.
    const ShardedBufferConfig& global_shard_spec() const;
    const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; }

    std::shared_ptr<Buffer> get_device_buffer(const Coordinate& device_coord);
    uint32_t datum_size_bytes() const;
    Shape2D physical_shard_shape() const;
    std::pair<bool, bool> replicated_dims() const;

private:
    MeshBuffer(