From 8b2c6cd16a27bd54675a018dc8a66013af3181e3 Mon Sep 17 00:00:00 2001 From: asaigal Date: Wed, 22 Jan 2025 09:07:10 -0800 Subject: [PATCH] #0: Add native 2D sharding and replication functionality to MeshBuffer - Add top level EnqueueWriteMeshBuffer and EnqueueReadMeshBuffer APIs to distributed.hpp --- .../tt_metal/distributed/test_mesh_buffer.cpp | 148 ++++++++++++++ tt_metal/distributed/distributed.hpp | 22 +++ tt_metal/distributed/mesh_buffer.cpp | 40 ++-- tt_metal/distributed/mesh_buffer.hpp | 14 ++ tt_metal/distributed/mesh_command_queue.cpp | 183 +++++++++++++++++- tt_metal/distributed/mesh_command_queue.hpp | 20 +- 6 files changed, 411 insertions(+), 16 deletions(-) diff --git a/tests/tt_metal/distributed/test_mesh_buffer.cpp b/tests/tt_metal/distributed/test_mesh_buffer.cpp index 7b8e916c0bd..aeae4e2f040 100644 --- a/tests/tt_metal/distributed/test_mesh_buffer.cpp +++ b/tests/tt_metal/distributed/test_mesh_buffer.cpp @@ -253,5 +253,153 @@ INSTANTIATE_TEST_SUITE_P( // shard_strategy ::testing::Values( TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::BLOCK_SHARDED))); + +TEST_F(MeshBufferTest, SweepShardAndConcat) { + uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32); + + DeviceLocalBufferConfig per_device_buffer_config{ + .page_size = single_tile_size, + .buffer_type = BufferType::DRAM, + .buffer_layout = TensorMemoryLayout::INTERLEAVED, + .bottom_up = true}; + std::vector global_buffer_shapes = { + {64, 128}, {128, 128}, {32, 1024}, {1024, 32}, {512, 64}, {2048, 2048}}; + std::vector shard_shapes = {{32, 32}, {32, 64}, {32, 128}, {128, 32}, {128, 32}, {512, 1024}}; + for (auto shard_orientation : {ShardOrientation::COL_MAJOR, ShardOrientation::ROW_MAJOR}) { + for (int i = 0; i < global_buffer_shapes.size(); i++) { + Shape2D global_buffer_shape = global_buffer_shapes[i]; + Shape2D shard_shape = shard_shapes[i]; + + uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t); + + ShardedBufferConfig sharded_config{ + .global_size = global_buffer_size, + .global_buffer_shape = global_buffer_shape, + .shard_shape = shard_shape, + .shard_orientation = shard_orientation, + }; + + auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get()); + std::vector src_vec = + std::vector(global_buffer_shape.height() * global_buffer_shape.width(), 0); + std::iota(src_vec.begin(), src_vec.end(), 0); + EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), mesh_buffer, src_vec); + std::vector dst_vec = {}; + EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(), dst_vec, mesh_buffer); + + EXPECT_EQ(dst_vec, src_vec); + } + } +} + +TEST_F(MeshBufferTest, RowMajorShardingAndReplication) { + uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32); + + DeviceLocalBufferConfig per_device_buffer_config{ + .page_size = single_tile_size, + .buffer_type = BufferType::DRAM, + .buffer_layout = TensorMemoryLayout::INTERLEAVED, + .bottom_up = true}; + + std::vector global_buffer_shapes = {{64, 256}, {128, 128}, {256, 2048}, {32, 512}, {512, 1024}}; + + for (int i = 0; i < global_buffer_shapes.size(); i++) { + auto global_buffer_shape = global_buffer_shapes[i]; + Shape2D shard_shape = {0, global_buffer_shape.width() / mesh_device_->num_cols()}; + // Mesh-Level Sharding Parameters for the MeshBufferView that will be read to verify correctness + Shape2D global_buffer_read_shape = { + global_buffer_shape.height() * 
mesh_device_->num_rows(), global_buffer_shape.width()}; + Shape2D shard_read_shape = { + global_buffer_shape.height(), global_buffer_shape.width() / mesh_device_->num_cols()}; + + uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t); + auto shard_orientation = ShardOrientation::ROW_MAJOR; + + ShardedBufferConfig sharded_config{ + .global_size = global_buffer_size, + .global_buffer_shape = global_buffer_shape, + .shard_shape = shard_shape, + .shard_orientation = shard_orientation, + }; + // Initialize the ShardedBufferConfig for reading and verifying replicated data + ShardedBufferConfig sharded_read_view_config{ + .global_size = global_buffer_read_shape.height() * global_buffer_read_shape.width() * sizeof(uint32_t), + .global_buffer_shape = global_buffer_read_shape, + .shard_shape = shard_read_shape, + .shard_orientation = shard_orientation}; + + auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get()); + std::vector src_vec = + std::vector(global_buffer_shape.height() * global_buffer_shape.width(), 0); + std::iota(src_vec.begin(), src_vec.end(), 0); + + auto mesh_buffer_read_view = MeshBuffer::create( + sharded_read_view_config, per_device_buffer_config, mesh_device_.get(), mesh_buffer->address()); + EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), mesh_buffer, src_vec); + std::vector dst_vec = + std::vector(global_buffer_read_shape.height() * global_buffer_read_shape.width(), 0); + EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(), dst_vec, mesh_buffer_read_view); + + for (int i = 0; i < dst_vec.size(); i++) { + EXPECT_EQ(dst_vec[i], i % (src_vec.size())); + } + } +} + +TEST_F(MeshBufferTest, ColMajorShardingAndReplication) { + uint32_t single_tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32); + + DeviceLocalBufferConfig per_device_buffer_config{ + .page_size = single_tile_size, + .buffer_type = BufferType::DRAM, + .buffer_layout = TensorMemoryLayout::INTERLEAVED, + .bottom_up = true}; + + std::vector global_buffer_shapes = {{256, 64}, {1024, 1024}, {128, 32}, {512, 64}, {2048, 256}}; + + for (int i = 0; i < global_buffer_shapes.size(); i++) { + auto global_buffer_shape = global_buffer_shapes[i]; + Shape2D shard_shape = {global_buffer_shape.height() / mesh_device_->num_rows(), 0}; + uint32_t global_buffer_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t); + Shape2D global_buffer_read_shape = { + global_buffer_shape.height(), global_buffer_shape.width() * mesh_device_->num_cols()}; + Shape2D shard_read_shape = { + global_buffer_shape.height() / mesh_device_->num_rows(), global_buffer_shape.width()}; + + ShardOrientation shard_orientation = ShardOrientation::COL_MAJOR; + + ShardedBufferConfig sharded_config{ + .global_size = global_buffer_size, + .global_buffer_shape = global_buffer_shape, + .shard_shape = shard_shape, + .shard_orientation = shard_orientation, + }; + + ShardedBufferConfig sharded_read_view_config{ + .global_size = global_buffer_read_shape.height() * global_buffer_read_shape.width() * sizeof(uint32_t), + .global_buffer_shape = global_buffer_read_shape, + .shard_shape = shard_read_shape, + .shard_orientation = ShardOrientation::ROW_MAJOR}; + + auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device_.get()); + std::vector src_vec = + std::vector(global_buffer_shape.height() * global_buffer_shape.width(), 0); + std::iota(src_vec.begin(), src_vec.end(), 0); + + auto 
mesh_buffer_read_view = MeshBuffer::create( + sharded_read_view_config, per_device_buffer_config, mesh_device_.get(), mesh_buffer->address()); + + EnqueueWriteMeshBuffer(mesh_device_->mesh_command_queue(), mesh_buffer, src_vec); + std::vector dst_vec = + std::vector(global_buffer_read_shape.height() * global_buffer_read_shape.width(), 0); + EnqueueReadMeshBuffer(mesh_device_->mesh_command_queue(), dst_vec, mesh_buffer_read_view); + for (int i = 0; i < dst_vec.size(); i++) { + EXPECT_EQ( + (i / global_buffer_read_shape.width()) * global_buffer_shape.width() + i % global_buffer_shape.width(), + dst_vec[i]); + } + } +} + } // namespace } // namespace tt::tt_metal::distributed::test diff --git a/tt_metal/distributed/distributed.hpp b/tt_metal/distributed/distributed.hpp index 6e10b142f56..2cca1710342 100644 --- a/tt_metal/distributed/distributed.hpp +++ b/tt_metal/distributed/distributed.hpp @@ -46,6 +46,28 @@ void ReadShard( mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking); } +template +void EnqueueWriteMeshBuffer( + MeshCommandQueue& mesh_cq, + std::shared_ptr& mesh_buffer, + std::vector& src, + bool blocking = false) { + mesh_cq.enqueue_write_mesh_buffer(mesh_buffer, src.data(), blocking); +} + +template +void EnqueueReadMeshBuffer( + MeshCommandQueue& mesh_cq, + std::vector& dst, + std::shared_ptr& mesh_buffer, + bool blocking = true) { + TT_FATAL( + mesh_buffer->global_layout() == MeshBufferLayout::SHARDED, + "Can only read a Sharded MeshBuffer from a MeshDevice."); + dst.resize(mesh_buffer->global_shard_spec().global_size / sizeof(DType)); + mesh_cq.enqueue_read_mesh_buffer(dst.data(), mesh_buffer, blocking); +} + void Finish(MeshCommandQueue& mesh_cq); } // namespace distributed diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp index b3db2c5846f..217098a573a 100644 --- a/tt_metal/distributed/mesh_buffer.cpp +++ b/tt_metal/distributed/mesh_buffer.cpp @@ -19,7 +19,7 @@ void validate_mesh_buffer_config(const MeshBufferConfig& config, const MeshDevic const auto& sharded_config = std::get(config); const auto [global_buffer_height, global_buffer_width] = sharded_config.global_buffer_shape; - const auto [shard_height, shard_width] = sharded_config.shard_shape; + const auto [shard_height, shard_width] = sharded_config.physical_shard_shape(); TT_FATAL( (global_buffer_height % shard_height == 0) and (global_buffer_width % shard_width == 0), @@ -32,13 +32,25 @@ void validate_mesh_buffer_config(const MeshBufferConfig& config, const MeshDevic const auto num_shard_rows = global_buffer_height / shard_height; const auto num_shard_cols = global_buffer_width / shard_width; - const auto num_shards = num_shard_rows * num_shard_cols; + auto num_shards = num_shard_rows * num_shard_cols; + + // The following check needs to account for shard orientation. The scaling factor for + // replication depends on which orientation we shard/replicate to when writing to device. + const auto& [height_replicated, width_replicated] = sharded_config.replicated_dims(); + if (height_replicated and width_replicated) { + // Pure replication + num_shards *= mesh_device.num_cols() * mesh_device.num_rows(); + } else if (height_replicated or width_replicated) { + // Replication along row or column dim. 
+ num_shards *= + ((sharded_config.shard_orientation == ShardOrientation::ROW_MAJOR) * (mesh_device.num_rows()) + + (sharded_config.shard_orientation == ShardOrientation::COL_MAJOR) * (mesh_device.num_cols())); + } TT_FATAL( num_shards <= mesh_device.num_devices(), - "The number of shards must align with the mesh shape: number of shards: {}, mesh shape: ({}, {})", + "The sharded tensor does not fit on the Mesh. Num shards in buffer {}, Num Devices {}", num_shards, - mesh_device.num_rows(), - mesh_device.num_cols()); + mesh_device.num_devices()); } } // namespace @@ -54,7 +66,7 @@ std::shared_ptr MeshBuffer::create( tt::stl::overloaded{ [](const ReplicatedBufferConfig& c) { return c.size; }, [mesh_device](const ShardedBufferConfig& config) { - const auto [shard_height, shard_width] = config.shard_shape; + const auto [shard_height, shard_width] = config.physical_shard_shape(); return config.compute_datum_size_bytes() * shard_height * shard_width; }}, mesh_buffer_config); @@ -158,14 +170,14 @@ Shape2D MeshBuffer::physical_shard_shape() const { this->global_layout() == MeshBufferLayout::SHARDED, "Can only query physical shard shape for buffers sharded across the Mesh"); auto sharded_config = std::get(config_); - Shape2D physical_shard_shape = sharded_config.shard_shape; - if (physical_shard_shape.height() == 0) { - physical_shard_shape = {sharded_config.global_buffer_shape.height(), physical_shard_shape.width()}; - } - if (physical_shard_shape.width() == 0) { - physical_shard_shape = {physical_shard_shape.height(), sharded_config.global_buffer_shape.width()}; - } - return physical_shard_shape; + return sharded_config.physical_shard_shape(); +} + +std::pair MeshBuffer::replicated_dims() const { + TT_FATAL( + this->global_layout() == MeshBufferLayout::SHARDED, + "Can only query replicated dims for buffers sharded across the Mesh"); + return this->global_shard_spec().replicated_dims(); } } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_buffer.hpp b/tt_metal/distributed/mesh_buffer.hpp index dfdd6e9bacc..53e063bbfdc 100644 --- a/tt_metal/distributed/mesh_buffer.hpp +++ b/tt_metal/distributed/mesh_buffer.hpp @@ -37,6 +37,8 @@ struct ReplicatedBufferConfig { // Specifies sharded MeshBuffer. struct ShardedBufferConfig { + // Note: Only 2D sharding and replication is supported by the APIs exposed through this struct. + // This interface will likely change over time depending on the status of native ND sharding. // Global buffer size. Each device will get a fraction of this size. DeviceAddr global_size = 0; @@ -53,6 +55,14 @@ struct ShardedBufferConfig { uint32_t compute_datum_size_bytes() const { return global_size / (global_buffer_shape.height() * global_buffer_shape.width()); } + + std::pair replicated_dims() const { return {shard_shape.height() == 0, shard_shape.width() == 0}; } + + Shape2D physical_shard_shape() const { + const auto [shard_height, shard_width] = shard_shape; + const auto [global_height, global_width] = global_buffer_shape; + return Shape2D(shard_height == 0 ? global_height : shard_height, shard_width == 0 ? global_width : shard_width); + } }; enum class MeshBufferLayout : uint8_t { REPLICATED, SHARDED }; @@ -75,12 +85,16 @@ class MeshBuffer { MeshBufferLayout global_layout() const; const MeshBufferConfig& global_config() const { return config_; } + // ND Sharding is not supported today. MeshBuffer only supports 2D sharding and + // replication. 
Tensor sharding schemes that can be lowered to 2D configurations + // are thus supported by the MeshCommandQueue. const ShardedBufferConfig& global_shard_spec() const; const DeviceLocalBufferConfig& device_local_config() const { return device_local_config_; } std::shared_ptr get_device_buffer(const Coordinate& device_coord); uint32_t datum_size_bytes() const; Shape2D physical_shard_shape() const; + std::pair replicated_dims() const; private: MeshBuffer( diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp index 1c7c33879c5..09d59864083 100644 --- a/tt_metal/distributed/mesh_command_queue.cpp +++ b/tt_metal/distributed/mesh_command_queue.cpp @@ -243,7 +243,7 @@ void MeshCommandQueue::enqueue_write_shard( } void MeshCommandQueue::enqueue_read_shard( - void* host_data, std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking) { + void* host_data, const std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking) { TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards."); // TODO: Add proper support for SubDevices once SubDeviceManager and allocator are moved up to MeshDevice // We should not be querying SubDevices from device 0. @@ -254,4 +254,185 @@ void MeshCommandQueue::enqueue_read_shard( this->read_shard_from_device(shard, host_data, expected_num_workers_completed, sub_device_ids); } +void MeshCommandQueue::write_sharded_buffer( + MeshBuffer& buffer, + const void* src, + std::array& expected_num_workers_completed, + tt::stl::Span sub_device_ids) { + auto global_buffer_shape = buffer.global_shard_spec().global_buffer_shape; + auto global_buffer_size = buffer.global_shard_spec().global_size; + + auto shard_shape = buffer.physical_shard_shape(); + auto datum_size_bytes = buffer.datum_size_bytes(); + + auto stride_size_bytes = datum_size_bytes * global_buffer_shape.width(); + auto single_read_size = datum_size_bytes * shard_shape.width(); + auto total_read_size_per_shard = single_read_size * shard_shape.height(); + + auto num_shards_x = global_buffer_shape.width() / shard_shape.width(); + auto num_shards_y = global_buffer_shape.height() / shard_shape.height(); + + uint32_t num_devices_x = buffer.device()->num_cols(); + uint32_t num_devices_y = buffer.device()->num_rows(); + + uint32_t device_x = 0; + uint32_t device_y = 0; + std::vector shard_data = std::vector(total_read_size_per_shard / sizeof(uint32_t), 0); + const auto& [height_replicated, width_replicated] = buffer.replicated_dims(); + for (std::size_t shard_y = 0; shard_y < num_shards_y; shard_y++) { + for (std::size_t shard_x = 0; shard_x < num_shards_x; shard_x++) { + auto read_offset = shard_x * single_read_size + shard_y * stride_size_bytes * shard_shape.height(); + uint32_t size_to_read = total_read_size_per_shard; + uint32_t local_offset = 0; + while (size_to_read) { + std::memcpy( + shard_data.data() + local_offset * (single_read_size / sizeof(uint32_t)), + (uint8_t*)(src) + read_offset + local_offset * stride_size_bytes, + single_read_size); + size_to_read -= single_read_size; + local_offset++; + } + + if (height_replicated and width_replicated) { + for (std::size_t replicated_device_x = 0; replicated_device_x < num_devices_x; replicated_device_x++) { + for (std::size_t replicated_device_y = 0; replicated_device_y < num_devices_y; + replicated_device_y++) { + auto device_shard_view = + buffer.get_device_buffer(Coordinate(replicated_device_y, replicated_device_x)); + this->write_shard_to_device( + device_shard_view, 
shard_data.data(), expected_num_workers_completed, sub_device_ids); + } + } + } else if (height_replicated or width_replicated) { + if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) { + for (auto replicated_device_y = 0; replicated_device_y < num_devices_y; replicated_device_y++) { + auto device_shard_view = buffer.get_device_buffer(Coordinate(replicated_device_y, device_x)); + this->write_shard_to_device( + device_shard_view, shard_data.data(), expected_num_workers_completed, sub_device_ids); + } + device_x++; + } else { + for (auto replicated_device_x = 0; replicated_device_x < num_devices_x; replicated_device_x++) { + auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, replicated_device_x)); + this->write_shard_to_device( + device_shard_view, shard_data.data(), expected_num_workers_completed, sub_device_ids); + } + device_y++; + } + } else { + auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x)); + this->write_shard_to_device( + device_shard_view, shard_data.data(), expected_num_workers_completed, sub_device_ids); + if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) { + if (++device_x == num_devices_x) { + device_x = 0; + ++device_y; + } + } else { + if (++device_y == num_devices_y) { + device_y = 0; + ++device_x; + } + } + } + } + } +} + +void MeshCommandQueue::read_sharded_buffer( + MeshBuffer& buffer, + void* dst, + std::array& expected_num_workers_completed, + tt::stl::Span sub_device_ids) { + const auto& [height_replicated, width_replicated] = buffer.replicated_dims(); + TT_FATAL( + not(height_replicated or width_replicated), "Cannot read a MeshBuffer that is replicated along any dimension."); + auto global_buffer_shape = buffer.global_shard_spec().global_buffer_shape; + auto shard_shape = buffer.physical_shard_shape(); + auto datum_size_bytes = buffer.datum_size_bytes(); + + auto stride_size_bytes = datum_size_bytes * global_buffer_shape.width(); + auto single_write_size = datum_size_bytes * shard_shape.width(); + auto total_write_size_per_shard = single_write_size * shard_shape.height(); + auto num_shards_x = global_buffer_shape.width() / shard_shape.width(); + auto num_shards_y = global_buffer_shape.height() / shard_shape.height(); + uint32_t num_devices_x = buffer.device()->num_cols(); + uint32_t num_devices_y = buffer.device()->num_rows(); + + uint32_t device_x = 0; + uint32_t device_y = 0; + + std::vector shard_data = std::vector(total_write_size_per_shard / sizeof(uint32_t), 0); + for (std::size_t shard_y = 0; shard_y < num_shards_y; shard_y++) { + for (std::size_t shard_x = 0; shard_x < num_shards_x; shard_x++) { + auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x)); + this->read_shard_from_device( + device_shard_view, shard_data.data(), expected_num_workers_completed, sub_device_ids); + uint32_t write_offset = shard_x * single_write_size + shard_y * stride_size_bytes * shard_shape.height(); + uint32_t size_to_write = total_write_size_per_shard; + uint32_t local_offset = 0; + while (size_to_write) { + std::memcpy( + (uint8_t*)(dst) + write_offset + local_offset * stride_size_bytes, + shard_data.data() + local_offset * (single_write_size / sizeof(uint32_t)), + single_write_size); + local_offset++; + size_to_write -= single_write_size; + } + if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) { + if (++device_x == num_devices_x) { + device_x = 0; + ++device_y; + } + } else { + if (++device_y == num_devices_y) { 
+ device_y = 0; + ++device_x; + } + } + } + } +} + +void MeshCommandQueue::enqueue_write_shard_to_sub_grid( + MeshBuffer& buffer, void* host_data, const LogicalDeviceRange& device_range, bool blocking) { + // TODO: Add proper support for SubDevices once SubDeviceManager and allocator are moved up to MeshDevice + // We should not be querying SubDevices from device 0. + auto sub_device_ids = tt::stl::Span(mesh_device_->get_device(0)->get_sub_device_ids()); + std::array expected_num_workers_completed; + expected_num_workers_completed[0] = expected_num_workers_completed_; + + if (buffer.global_layout() == MeshBufferLayout::REPLICATED) { + for (std::size_t logical_x = device_range.start_coord.x; logical_x < device_range.end_coord.x; logical_x++) { + for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y; + logical_y++) { + auto device_shard_view = buffer.get_device_buffer(Coordinate(logical_y, logical_x)); + this->write_shard_to_device( + device_shard_view, host_data, expected_num_workers_completed, sub_device_ids); + } + } + } else { + this->write_sharded_buffer(buffer, host_data, expected_num_workers_completed, sub_device_ids); + } + if (blocking) { + this->finish(); + } +} + +void MeshCommandQueue::enqueue_write_mesh_buffer( + const std::shared_ptr& buffer, void* host_data, bool blocking) { + LogicalDeviceRange mesh_device_extent({0, 0}, {buffer->device()->num_cols(), buffer->device()->num_rows()}); + this->enqueue_write_shard_to_sub_grid(*buffer, host_data, mesh_device_extent, blocking); +} + +void MeshCommandQueue::enqueue_read_mesh_buffer( + void* host_data, const std::shared_ptr& buffer, bool blocking) { + TT_FATAL( + buffer->global_layout() == MeshBufferLayout::SHARDED, "Can only read a Sharded MeshBuffer from a MeshDevice."); + auto sub_device_ids = tt::stl::Span(mesh_device_->get_device(0)->get_sub_device_ids()); + std::array expected_num_workers_completed; + expected_num_workers_completed[0] = expected_num_workers_completed_; + this->read_sharded_buffer(*buffer, host_data, expected_num_workers_completed, sub_device_ids); +} + } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_command_queue.hpp b/tt_metal/distributed/mesh_command_queue.hpp index 034eec7682e..54f8bc7e982 100644 --- a/tt_metal/distributed/mesh_command_queue.hpp +++ b/tt_metal/distributed/mesh_command_queue.hpp @@ -23,6 +23,7 @@ class MeshCommandQueue { void populate_dispatch_core_type(); CoreCoord virtual_program_dispatch_core() const; CoreType dispatch_core_type() const; + // Helper functions for reading and writing individual shards void write_shard_to_device( std::shared_ptr& shard_view, const void* src, @@ -33,6 +34,17 @@ class MeshCommandQueue { void* dst, std::array& expected_num_workers_completed, tt::stl::Span sub_device_ids); + // Helper functions for read and write entire Sharded-MeshBuffers + void write_sharded_buffer( + MeshBuffer& buffer, + const void* src, + std::array& expected_num_workers_completed, + tt::stl::Span sub_device_ids); + void read_sharded_buffer( + MeshBuffer& buffer, + void* dst, + std::array& expected_num_workers_completed, + tt::stl::Span sub_device_ids); tt::tt_metal::WorkerConfigBufferMgr config_buffer_mgr_; LaunchMessageRingBufferState worker_launch_message_buffer_state_; uint32_t expected_num_workers_completed_ = 0; @@ -47,10 +59,16 @@ class MeshCommandQueue { uint32_t id() const { return id_; } WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr_; }; void 
enqueue_mesh_workload(MeshWorkload& mesh_workload, bool blocking); + // MeshBuffer Write APIs void enqueue_write_shard( std::shared_ptr& mesh_buffer, void* host_data, const Coordinate& coord, bool blocking); + void enqueue_write_shard_to_sub_grid( + MeshBuffer& buffer, void* host_data, const LogicalDeviceRange& device_range, bool blocking); + void enqueue_write_mesh_buffer(const std::shared_ptr& buffer, void* host_data, bool blocking); + // MeshBuffer Read APIs void enqueue_read_shard( - void* host_data, std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking); + void* host_data, const std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking); + void enqueue_read_mesh_buffer(void* host_data, const std::shared_ptr& buffer, bool blocking); void finish(); };
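
Usage sketch for the new mesh-level APIs (a minimal sketch only, mirroring the SweepShardAndConcat test above; the shapes, data format, and variable names are illustrative, and an already-opened MeshDevice with an accessible mesh_command_queue() is assumed):

    // Shard a 64x128 uint32 buffer into 32x32 shards across the mesh, write it
    // from host, then read it back through the same MeshBuffer.
    uint32_t tile_size = ::tt::tt_metal::detail::TileSize(DataFormat::UInt32);

    DeviceLocalBufferConfig per_device_buffer_config{
        .page_size = tile_size,
        .buffer_type = BufferType::DRAM,
        .buffer_layout = TensorMemoryLayout::INTERLEAVED,
        .bottom_up = true};

    Shape2D global_buffer_shape = {64, 128};  // illustrative global shape
    Shape2D shard_shape = {32, 32};           // a 0 in either dimension requests replication along it

    ShardedBufferConfig sharded_config{
        .global_size = global_buffer_shape.height() * global_buffer_shape.width() * sizeof(uint32_t),
        .global_buffer_shape = global_buffer_shape,
        .shard_shape = shard_shape,
        .shard_orientation = ShardOrientation::ROW_MAJOR};

    auto mesh_buffer = MeshBuffer::create(sharded_config, per_device_buffer_config, mesh_device.get());

    std::vector<uint32_t> src(global_buffer_shape.height() * global_buffer_shape.width());
    std::iota(src.begin(), src.end(), 0);
    EnqueueWriteMeshBuffer(mesh_device->mesh_command_queue(), mesh_buffer, src);

    std::vector<uint32_t> dst;  // resized by EnqueueReadMeshBuffer to global_size / sizeof(uint32_t)
    EnqueueReadMeshBuffer(mesh_device->mesh_command_queue(), dst, mesh_buffer);  // blocking by default

Note that EnqueueReadMeshBuffer only accepts sharded MeshBuffers; as the replication tests above show, replicated data is verified by creating a second sharded MeshBuffer "read view" at the same device address and reading through that view instead.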