Uplift metal: fixed mesh device include, reshape, and memory constants
jnie-TT committed Oct 8, 2024
1 parent cbd4e6d commit 94d52b5
Showing 6 changed files with 21 additions and 57 deletions.
runtime/include/tt/runtime/detail/ttmetal.h (1 addition, 1 deletion)

@@ -39,7 +39,7 @@
 #pragma clang diagnostic ignored "-Wlogical-op-parentheses"
 #pragma clang diagnostic ignored "-Wundefined-inline"
 #define FMT_HEADER_ONLY
-#include "impl/device/mesh_device.hpp"
+#include "distributed/mesh_device.hpp"
 #include "impl/event/event.hpp"
 #include "tt_metal/host_api.hpp"
 #pragma clang diagnostic pop
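Note: mesh_device.hpp moved out of impl/device/ into distributed/ in the uplifted tt-metal. For out-of-tree code that must build against both layouts during a transition, a minimal sketch using C++17 __has_include (an illustration only; this commit itself just switches to the new path):

// Hedged sketch: prefer the new header location, fall back to the old one.
#if __has_include("distributed/mesh_device.hpp")
#include "distributed/mesh_device.hpp" // post-uplift location
#else
#include "impl/device/mesh_device.hpp" // pre-uplift location
#endif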
runtime/include/tt/runtime/detail/ttnn.h (1 addition, 1 deletion)

@@ -42,9 +42,9 @@
 #pragma clang diagnostic ignored "-Wc99-extensions"

 #define FMT_HEADER_ONLY
+#include "distributed/mesh_device.hpp"
 #include "host_api.hpp"
 #include "hostdevcommon/common_values.hpp"
-#include "impl/device/mesh_device.hpp"
 #include "ttnn/device.hpp"
 #include "ttnn/operations/conv/conv2d/conv2d.hpp"
 #include "ttnn/operations/copy.hpp"
runtime/lib/common/system_desc.cpp (12 additions, 5 deletions)

@@ -28,9 +28,9 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-local-typedef"
 #define FMT_HEADER_ONLY
+#include "distributed/mesh_device.hpp"
 #include "host_api.hpp"
 #include "hostdevcommon/common_values.hpp"
-#include "impl/device/mesh_device.hpp"
 #pragma clang diagnostic pop

 namespace tt::runtime::system_desc {
@@ -158,7 +158,8 @@ calculateDRAMUnreservedEnd(const ::tt::tt_metal::Device *device) {
   std::uint32_t totalCores = deviceGridSize.x * deviceGridSize.y +
                              device->get_active_ethernet_cores().size();
   std::uint32_t totalDramCores = dramGridSize.x * dramGridSize.y;
-  std::uint32_t programCarveOutPerCore = L1_UNRESERVED_BASE;
+  std::uint32_t programCarveOutPerCore =
+      device->get_base_allocator_addr(::tt::tt_metal::HalMemType::L1);
   std::uint32_t totalProgramCarveOut = programCarveOutPerCore * totalCores;
   // The total carve out can be interleaved between all dram channels
   std::uint32_t programCarveOutDramSpace =
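A worked sketch of the carve-out arithmetic in this hunk. All device values below are hypothetical stand-ins, and the per-channel split is an assumed ceiling division, since the diff is truncated before the actual expression:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical values; the real code queries them from the device.
  std::uint32_t totalCores = 8 * 8 + 2;  // worker grid plus active eth cores
  std::uint32_t totalDramCores = 1 * 12; // DRAM grid
  // Previously the global L1_UNRESERVED_BASE; now queried per device via
  // get_base_allocator_addr(HalMemType::L1). 0x18000 is a made-up value.
  std::uint32_t programCarveOutPerCore = 0x18000;
  std::uint32_t totalProgramCarveOut = programCarveOutPerCore * totalCores;
  // The total carve-out can be interleaved between all DRAM channels;
  // assumed ceiling division across channels.
  std::uint32_t programCarveOutDramSpace =
      (totalProgramCarveOut + totalDramCores - 1) / totalDramCores;
  std::cout << "per-channel carve-out: " << programCarveOutDramSpace << " bytes\n";
}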
@@ -190,6 +191,11 @@ getCurrentSystemDescImpl(const ::tt::tt_metal::MeshDevice &meshDevice) {
   ::flatbuffers::FlatBufferBuilder fbb;

   for (const ::tt::tt_metal::Device *device : devices) {
+    size_t l1UnreservedBase =
+        device->get_base_allocator_addr(::tt::tt_metal::HalMemType::L1);
+    size_t dramUnreservedBase =
+        device->get_base_allocator_addr(::tt::tt_metal::HalMemType::DRAM);
+
     // Construct chip descriptor
     ::tt::target::Dim2d deviceGrid =
         toFlatbuffer(device->compute_with_storage_grid_size());
@@ -223,9 +229,10 @@ getCurrentSystemDescImpl(const ::tt::tt_metal::MeshDevice &meshDevice) {
         fbb, toFlatbuffer(device->arch()), &deviceGrid,
         device->l1_size_per_core(), device->num_dram_channels(),
         device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT,
-        DRAM_ALIGNMENT, L1_UNRESERVED_BASE, ERISC_L1_UNRESERVED_BASE,
-        DRAM_UNRESERVED_BASE, dramUnreservedEnd, chipPhysicalCores,
-        supportedDataTypes, supportedTileSizes));
+        DRAM_ALIGNMENT, l1UnreservedBase,
+        ::eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, dramUnreservedBase,
+        dramUnreservedEnd, chipPhysicalCores, supportedDataTypes,
+        supportedTileSizes));
     chipDescIndices.push_back(device->id());
     // Derive chip capability
     ::tt::target::ChipCapability chipCapability =
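The point of the two hunks above: reserved-region sizes are no longer compile-time constants, because they can differ per device once a mesh mixes architectures or firmware builds. A toy sketch of the query pattern (only get_base_allocator_addr and HalMemType echo the diff; everything else is invented):

#include <cstddef>

enum class HalMemType { L1, DRAM };

struct Device {
  // Stand-in for the per-device query the diff switches to.
  std::size_t get_base_allocator_addr(HalMemType t) const {
    return t == HalMemType::L1 ? 0x18000 : 0x20; // made-up base addresses
  }
};

int main() {
  Device device;
  std::size_t l1UnreservedBase = device.get_base_allocator_addr(HalMemType::L1);
  std::size_t dramUnreservedBase =
      device.get_base_allocator_addr(HalMemType::DRAM);
  return (l1UnreservedBase != 0 && dramUnreservedBase != 0) ? 0 : 1;
}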
runtime/lib/ttnn/operations/context/get_device.cpp (2 additions, 2 deletions)

@@ -14,8 +14,8 @@ deriveMeshViewCoordinates(const ::ttnn::MeshDevice &meshDevice,
                           const std::unordered_set<uint32_t> &desiredDeviceIds,
                           const ::tt::target::Dim2d *meshViewShape) {
   ::tt::tt_metal::Coordinate topLeft, bottomRight;
-  for (int row = 0; row < meshDevice.num_rows(); row++) {
-    for (int col = 0; col < meshDevice.num_cols(); col++) {
+  for (size_t row = 0; row < meshDevice.num_rows(); row++) {
+    for (size_t col = 0; col < meshDevice.num_cols(); col++) {
       const ::ttnn::Device *currDevice = meshDevice.get_device(row, col);
       if (desiredDeviceIds.contains(currDevice->id())) {
         topLeft.row = row;
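The int-to-size_t change reads like a signedness fix: if num_rows()/num_cols() now return an unsigned type, signed loop indices trip -Wsign-compare under -Werror. A toy illustration under that assumption:

#include <cstddef>

// Toy stand-ins: assume the mesh accessors now return size_t.
std::size_t num_rows() { return 2; }
std::size_t num_cols() { return 4; }

int countCells() {
  int cells = 0;
  // for (int row = 0; ...) would compare signed to unsigned and warn under
  // -Wsign-compare; size_t indices keep the comparison unsigned throughout.
  for (std::size_t row = 0; row < num_rows(); row++) {
    for (std::size_t col = 0; col < num_cols(); col++) {
      cells++;
    }
  }
  return cells;
}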
runtime/lib/ttnn/operations/data_movement/reshape.cpp (1 addition, 45 deletions)

@@ -6,56 +6,12 @@
#include "tt/runtime/detail/ttnn.h"

namespace tt::runtime::ttnn::operations::data_movement {

template <int32_t Rank>
static std::array<int32_t, Rank>
vectorToArray(const std::vector<int32_t> &vec) {
if (vec.size() != Rank) {
throw std::invalid_argument("Vector size does not match array size");
}
std::array<int32_t, Rank> arr;
std::copy(vec.begin(), vec.end(), arr.begin());
return arr;
}

template <int32_t Rank>
static ::ttnn::Tensor invoke_reshape(const ::ttnn::Tensor &tensor,
const std::vector<int32_t> &shape) {
return ::ttnn::reshape(tensor, vectorToArray<Rank>(shape));
}

void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
const auto *fbShape = op->shape();
std::vector<int32_t> shape(fbShape->begin(), fbShape->end());
constexpr int32_t Rank1 = 1;
constexpr int32_t Rank2 = 2;
constexpr int32_t Rank3 = 3;
constexpr int32_t Rank4 = 4;
constexpr int32_t Rank5 = 5;

::ttnn::Tensor out;
switch (fbShape->size()) {
case Rank1:
out = invoke_reshape<Rank1>(in, shape);
break;
case Rank2:
out = invoke_reshape<Rank2>(in, shape);
break;
case Rank3:
out = invoke_reshape<Rank3>(in, shape);
break;
case Rank4:
out = invoke_reshape<Rank4>(in, shape);
break;
case Rank5:
out = invoke_reshape<Rank5>(in, shape);
break;
default:
throw std::invalid_argument("Unsupported rank for reshape");
}

::ttnn::Tensor out = ::ttnn::reshape(in, shape);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::data_movement
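All of the deleted rank plumbing existed to convert a runtime-length std::vector into the compile-time-sized std::array the old ::ttnn::reshape overload required; the uplifted API evidently accepts a runtime shape directly. A standalone sketch of that before/after (toy volume() API, not ttnn itself):

#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Old-style fixed-rank API: callers must know Rank at compile time.
template <std::size_t Rank>
int64_t volume(const std::array<int32_t, Rank> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<>());
}

// New-style runtime-rank API: one overload covers every rank.
int64_t volume(const std::vector<int32_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<>());
}

int main() {
  std::vector<int32_t> shape = {2, 3, 4};
  // The deleted code needed switch (shape.size()) { case 3: f<3>(...); ... }
  // to bridge runtime rank to the array overload; the vector overload doesn't.
  return volume(shape) == 24 ? 0 : 1;
}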
runtime/lib/ttnn/operations/matmul/matmul.cpp (4 additions, 3 deletions)

@@ -17,9 +17,10 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
       utils::createMemoryConfig(op->out());
   ::ttnn::Tensor out = ::ttnn::operations::matmul::matmul(
       lhs, rhs, /*bias=*/std::nullopt,
-      ::ttnn::operations::matmul::Matmul{/*program_config=*/std::nullopt,
-                                         /*bcast_batch=*/std::nullopt,
-                                         outputMemoryConfig, outputDataType});
+      ::ttnn::operations::matmul::Matmul{.output_mem_config =
+                                             outputMemoryConfig,
+                                         .output_dtype = outputDataType,
+                                         .output_tile = std::nullopt});
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 } // namespace tt::runtime::ttnn::operations::matmul
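The matmul change swaps positional /*comment*/ arguments for C++20 designated initializers, so the call site names exactly the fields it sets (output_mem_config, output_dtype, and the new output_tile) and lets the rest default. A toy sketch of why that is more robust (invented struct; only the field names echo the diff):

#include <optional>

// Toy struct, not the real ::ttnn::operations::matmul::Matmul; types and
// defaults are invented for illustration.
struct MatmulParams {
  std::optional<int> program_config;
  std::optional<bool> bcast_batch;
  int output_mem_config = 0;
  int output_dtype = 0;
  std::optional<int> output_tile;
};

int main() {
  // Naming fields keeps the call site valid when new members are added
  // upstream, and turns layout mismatches into compile errors instead of
  // the silent misassignment that positional arguments allow.
  MatmulParams p{.output_mem_config = 1,
                 .output_dtype = 2,
                 .output_tile = std::nullopt};
  return p.output_dtype == 2 ? 0 : 1;
}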