Enable mnist sharding with layout overrides (#894)
* Generated ToLayout ops now have a suffix appended to their location name
* Add workaround for ttnn failing to choose a 1D matmul program config
odjuricicTT authored Oct 16, 2024
1 parent 5d03b9a commit 21955b4
Showing 7 changed files with 201 additions and 16 deletions.
27 changes: 22 additions & 5 deletions lib/Dialect/TTIR/Transforms/Layout.cpp
@@ -99,6 +99,19 @@ getLegalTensorMemoryLayout(OperandConstraint operandConstraint,
return TensorMemoryLayout::None;
}

inline Location appendInputSuffix(Location loc, int64_t operandIndex) {
if (isa<NameLoc>(loc)) {
NameLoc oldLoc = mlir::cast<NameLoc>(loc);
StringAttr newName = StringAttr::get(
loc->getContext(), oldLoc.getName().str() + "_in_" +
std::to_string(operandIndex) + "_layout");

return NameLoc::get(newName, oldLoc.getChildLoc());
}

return loc;
}

//===----------------------------------------------------------------------===//
// To layout pass
//===----------------------------------------------------------------------===//
@@ -297,9 +310,11 @@ class TTIRLayoutDPSOperandsRewriter
mlir::cast<TTIROp>(op.getOperation())
.getOperandConstraints()[operand.getOperandNumber()])
.getValue();
auto desiredLayout = createToLayoutOp(
rewriter, op.getLoc(), operand.get(), operandConstraint,
defaultMemorySpace, defaultDeviceMemoryLayout);
Location newLoc =
appendInputSuffix(op.getLoc(), operand.getOperandNumber());
auto desiredLayout =
createToLayoutOp(rewriter, newLoc, operand.get(), operandConstraint,
defaultMemorySpace, defaultDeviceMemoryLayout);

if (desiredLayout) {
rewriter.modifyOpInPlace(op, [&]() {
@@ -341,9 +356,11 @@ class TTIRLayoutFuncReturnRewriter
if (isDeviceMemorySpace(initMemorySpace)) {
initMemoryLayout = defaultDeviceMemoryLayout;
}
Location newLoc =
appendInputSuffix(op.getLoc(), operand.getOperandNumber());
if (auto layout =
createToLayoutOp(rewriter, op.getLoc(), operand.get(),
initMemorySpace, initMemoryLayout, tiled);
createToLayoutOp(rewriter, newLoc, operand.get(), initMemorySpace,
initMemoryLayout, tiled);
layout) {
rewriter.modifyOpInPlace(
op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); });
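Note: the location names generated by the new appendInputSuffix helper follow a "<op>_in_<operandIndex>_layout" pattern. The standalone C++ sketch below only illustrates that string convention; the pass itself builds MLIR NameLocs, and "matmul_1" is an example name taken from the tests further down.

#include <cstdint>
#include <iostream>
#include <string>

// Sketch of the naming convention only (not the pass implementation).
std::string toLayoutLocName(const std::string &opLocName, int64_t operandIndex) {
  return opLocName + "_in_" + std::to_string(operandIndex) + "_layout";
}

int main() {
  // ToLayout ops created for matmul_1's inputs get these location names,
  // which override-output-layout keys can then target directly.
  std::cout << toLayoutLocName("matmul_1", 0) << "\n"; // matmul_1_in_0_layout
  std::cout << toLayoutLocName("matmul_1", 1) << "\n"; // matmul_1_in_1_layout
  return 0;
}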
28 changes: 20 additions & 8 deletions runtime/include/tt/runtime/detail/workarounds.h
@@ -16,12 +16,13 @@ struct Env {
constexpr static Env
#endif
get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true,
bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true)
bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true,
bool setMatmul1DProgramConfig = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true, true);
return Env(true, true, true, true, true);
}
#endif
// TODO(bug #272), determine correct layout by tile shape in the future
@@ -41,21 +42,32 @@ struct Env {
// instead of adding a method in runtime
bool maxpool2dPreshard;

// TODO(bug #891): ttnn::matmul doesn't choose the correct program config.
bool setMatmul1DProgramConfig;

private:
constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard)
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool setMatmul1DProgramConfig)
: ignoreTileShape(ignoreTileShape),
emptyOpForceRowMajor(emptyOpForceRowMajor),
fullOpForceRowMajor(fullOpForceRowMajor),
maxpool2dPreshard(maxpool2dPreshard) {}
maxpool2dPreshard(maxpool2dPreshard),
setMatmul1DProgramConfig(setMatmul1DProgramConfig) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
os << "workaround::Env{\n";
os << "\t" << "ignoreTileShape: " << env.ignoreTileShape << ",\n";
os << "\t" << "emptyOpForceRowMajor: " << env.emptyOpForceRowMajor << ",\n";
os << "\t" << "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n";
os << "\t" << "maxpool2dPreshard: " << env.maxpool2dPreshard << "\n";
os << "\t"
<< "ignoreTileShape: " << env.ignoreTileShape << ",\n";
os << "\t"
<< "emptyOpForceRowMajor: " << env.emptyOpForceRowMajor << ",\n";
os << "\t"
<< "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n";
os << "\t"
<< "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
os << "\t"
<< "setMatmul1DProgramConfig: " << env.setMatmul1DProgramConfig << "\n";
os << "}";
return os;
}
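Note: a minimal sketch of opting out of the new workaround while keeping the others, assuming a build with TT_RUNTIME_WORKAROUNDS enabled. Arguments are positional per the Env::get signature above, and only the first call in a process takes effect because the configuration is cached in a static local.

#include <iostream>

#include "tt/runtime/detail/workarounds.h"

int main() {
  // Keep the existing workarounds, disable only the matmul 1D program
  // config workaround (last parameter).
  const auto &env = tt::runtime::workaround::Env::get(
      /*ignoreTileShape=*/true, /*emptyOpForceRowMajor=*/true,
      /*fullOpForceRowMajor=*/true, /*maxpool2dPreshard=*/true,
      /*setMatmul1DProgramConfig=*/false);
  std::cout << env << std::endl;
  return 0;
}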
6 changes: 4 additions & 2 deletions runtime/lib/common/workarounds.cpp
@@ -7,9 +7,11 @@
namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard) {
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool setMatmul1DProgramConfig) {
static const Env config(ignoreTileShape, emptyOpForceRowMajor,
fullOpForceRowMajor, maxpool2dPreshard);
fullOpForceRowMajor, maxpool2dPreshard,
setMatmul1DProgramConfig);
return config;
}
#endif
85 changes: 84 additions & 1 deletion runtime/lib/ttnn/operations/matmul/matmul.cpp
@@ -5,10 +5,81 @@
#include "matmul.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/detail/workarounds.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include <optional>

// ANCHOR: adding_an_op_matmul_runtime_operations
namespace tt::runtime::ttnn::operations::matmul {

// This is a workaround for the lack of program config selection in ttnn.matmul.
// The logic here is temporary and far from complete.
::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig
createProgramConfig(const ::tt::target::ttnn::MatmulOp *op,
ProgramContext &context,
::tt::tt_metal::MemoryConfig outputMemoryConfig) {

uint32_t numCores = outputMemoryConfig.shard_spec->grid.num_cores();
bool fuseBatch = true; // required for sharded inputs

ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());

// Note: ttnn::Shape::value returns a legacy tt::tt_metal::Shape object,
// which does take padding into account.
uint32_t volume = 1;
for (size_t i = 0; i < lhs.shape().rank(); i++) {
volume *= lhs.shape().value[i];
}

uint32_t M =
fuseBatch ? volume / lhs.shape().value[-1] : lhs.shape().value[-2];
// uint32_t K = lhs.shape().value[-1];
uint32_t N = rhs.shape().value[-1];
bool mcastIn0 = N >= M;

uint32_t perCoreM, perCoreN;

if (mcastIn0) {
perCoreM = M / tt::constants::TILE_HEIGHT;
perCoreN = tt::div_up(tt::div_up(N, numCores), tt::constants::TILE_WIDTH);
} else {
perCoreM = tt::div_up(tt::div_up(M, numCores), tt::constants::TILE_HEIGHT);
perCoreN = N / tt::constants::TILE_WIDTH;
}

// uint32_t in0_block_w = (K / tt::constants::TILE_WIDTH) % 2 == 0 ? 2 : 1;
uint32_t in0BlockW = 1;

// These should work in most cases, but there is room to optimize
// this later.
uint32_t outSubblockH = 1, outSubblockW = 1;

assert(outputMemoryConfig.shard_spec->grid.ranges().size() == 1);
CoreCoord computeWithStorageGridSize =
outputMemoryConfig.shard_spec->grid.ranges().begin()->grid_size();
if (lhs.is_sharded()) {
CoreCoord lhs_grid_size =
lhs.shard_spec()->grid.ranges().begin()->grid_size();
if (computeWithStorageGridSize < lhs_grid_size) {
computeWithStorageGridSize = lhs_grid_size;
}
}

return ::ttnn::operations::matmul::
MatmulMultiCoreReuseMultiCast1DProgramConfig{
.compute_with_storage_grid_size = computeWithStorageGridSize,
.in0_block_w = in0BlockW,
.out_subblock_h = outSubblockH,
.out_subblock_w = outSubblockW,
.per_core_M = perCoreM,
.per_core_N = perCoreN,
.fuse_batch = true,
.fused_activation = std::nullopt,
.mcast_in0 = mcastIn0};
};

void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
@@ -18,9 +89,21 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
::ttnn::DataType outputDataType = utils::getDataType(op->out());
::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

std::optional<
::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig>
programConfig = std::nullopt;

// TODO(bug #891): ttnn::matmul doesn't choose the correct program config.
if (workaround::Env::get().setMatmul1DProgramConfig &&
outputMemoryConfig.memory_layout ==
::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) {
programConfig = createProgramConfig(op, context, outputMemoryConfig);
}

::ttnn::Tensor out = ::ttnn::operations::matmul::matmul(
lhs, rhs, /*bias=*/std::nullopt,
::ttnn::operations::matmul::Matmul{/*program_config=*/std::nullopt,
::ttnn::operations::matmul::Matmul{/*program_config=*/programConfig,
/*bcast_batch=*/std::nullopt,
outputMemoryConfig, outputDataType});
tensorPool.insert_or_assign(op->out()->global_id(), out);
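Note: to make the heuristic above concrete, here is a hand-worked instance (a standalone sketch, not runtime code) for the first MNIST matmul in the new silicon test: lhs 32x784, rhs 784x256, output 32x256 width-sharded on a 1x8 grid, so each of the 8 cores holds a 32x32 slice (the memref<32x32xf32, #l1_> shard in the test's LAYOUT_1 check).

#include <cassert>
#include <cstdint>
#include <iostream>

constexpr uint32_t TILE_H = 32, TILE_W = 32; // tt::constants::TILE_HEIGHT/WIDTH

constexpr uint32_t divUp(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  uint32_t numCores = 8;    // 1x8 shard grid
  uint32_t M = 32, N = 256; // fuseBatch: M = volume / K = (32 * 784) / 784
  bool mcastIn0 = N >= M;   // true, so in0 is multicast and N is split
  uint32_t perCoreM = mcastIn0 ? M / TILE_H : divUp(divUp(M, numCores), TILE_H);
  uint32_t perCoreN = mcastIn0 ? divUp(divUp(N, numCores), TILE_W) : N / TILE_W;
  assert(perCoreM == 1 && perCoreN == 1); // one 32x32 tile per core
  std::cout << "per_core_M=" << perCoreM << " per_core_N=" << perCoreN
            << " mcast_in0=" << mcastIn0 << "\n";
  return 0;
}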
8 changes: 8 additions & 0 deletions runtime/tools/python/ttrt/common/run.py
@@ -151,6 +151,13 @@ def initialize_api():
choices=[True, False],
help="Disable maxpool2d preshard workaround",
)
Run.register_arg(
name="--disable-matmul-1d-program-config",
type=bool,
default=False,
choices=[True, False],
help="Disable matmul 1d program config workaround",
)
Run.register_arg(
name="binary",
type=str,
@@ -331,6 +338,7 @@ def _execute(binaries):
not self["--disable-empty-op-row-major"],
not self["--disable-full-op-row-major"],
not self["--disable-maxpool2d-preshard"],
not self["--disable-matmul-1d-program-config"],
)
self.logging.debug(f"setting tt runtime workaround env={workaround_env}")
self.logging.debug(f"setting torch manual seed={self['--seed']}")
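Note: the new workaround is on by default in ttrt and can be disabled from the command line with the flag added above, e.g. roughly ttrt run <binary.ttnn> --disable-matmul-1d-program-config (the exact invocation depends on how ttrt is installed and driven).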
21 changes: 21 additions & 0 deletions test/ttmlir/Dialect/TTNN/input_layout_loc_override.mlir
@@ -0,0 +1,21 @@
// RUN: ttmlir-opt --mlir-print-debuginfo --ttir-to-ttnn-backend-pipeline="enable-optimizer=true sharding-pass-enabled=true override-output-layout=matmul_1_in_1_layout=1x1:l1:interleaved" %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("Matmul":4294967295:0)
// CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = loc("matmul_1_in_0_layout"(#loc3))
// CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3))
// CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3))
// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<128x96xbf16, #l1_>, interleaved>

module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
%0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2)
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]])
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
// CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]])
%1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2)
return %1 : tensor<64x96xbf16>
} loc(#loc)
} loc(#loc)

#loc1 = loc("Matmul":4294967295:1)
#loc2 = loc("matmul_1"(#loc1))
42 changes: 42 additions & 0 deletions test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
@@ -0,0 +1,42 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true sharding-pass-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,add_2_in_1_layout=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,add_6_in_1_layout=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("MNISTLinear":4294967295:0)
module @"tt-forge-graph" attributes {} {
func.func @main(%arg0: tensor<32x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<32x32xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x32xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<32x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<32x32xf32> {
// CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<32x32xf32, #l1_>, width_sharded>
// CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
%0 = tensor.empty() : tensor<32x256xf32> loc(#loc8)
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
%1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x784xf32>, tensor<784x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc8)
%2 = tensor.empty() : tensor<32x256xf32> loc(#loc9)
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
%3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x256xf32>, tensor<32x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc9)
%4 = tensor.empty() : tensor<32x256xf32> loc(#loc10)
// CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
%5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<32x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc10)
%6 = tensor.empty() : tensor<32x32xf32> loc(#loc11)
// CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<32x32xf32, #[[LAYOUT_2]]>
%7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x256xf32>, tensor<256x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc11)
%8 = tensor.empty() : tensor<32x32xf32> loc(#loc12)
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xf32, #[[LAYOUT_2]]>
%9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc12)
%10 = tensor.empty() : tensor<32x32xf32> loc(#loc13)
%11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc13)
return %11 : tensor<32x32xf32> loc(#loc7)
} loc(#loc)
} loc(#loc)
#loc1 = loc("MNISTLinear":4294967295:10)
#loc2 = loc("MNISTLinear":4294967295:8)
#loc3 = loc("MNISTLinear":4294967295:6)
#loc4 = loc("MNISTLinear":4294967295:4)
#loc5 = loc("MNISTLinear":4294967295:3)
#loc6 = loc("MNISTLinear":4294967295:2)
#loc7 = loc(unknown)
#loc8 = loc("matmul_1"(#loc1))
#loc9 = loc("add_2"(#loc2))
#loc10 = loc("relu_3"(#loc3))
#loc11 = loc("matmul_5"(#loc4))
#loc12 = loc("add_6"(#loc5))
#loc13 = loc("softmax_7"(#loc6))
