diff --git a/lib/Dialect/TTIR/Transforms/Layout.cpp b/lib/Dialect/TTIR/Transforms/Layout.cpp
index 1c85baab7..e6bca2864 100644
--- a/lib/Dialect/TTIR/Transforms/Layout.cpp
+++ b/lib/Dialect/TTIR/Transforms/Layout.cpp
@@ -99,6 +99,19 @@ getLegalTensorMemoryLayout(OperandConstraint operandConstraint,
   return TensorMemoryLayout::None;
 }
 
+inline Location appendInputSuffix(Location loc, int64_t operandIndex) {
+  if (isa<NameLoc>(loc)) {
+    NameLoc oldLoc = mlir::cast<NameLoc>(loc);
+    StringAttr newName = StringAttr::get(
+        loc->getContext(), oldLoc.getName().str() + "_in_" +
+                               std::to_string(operandIndex) + "_layout");
+
+    return NameLoc::get(newName, oldLoc.getChildLoc());
+  }
+
+  return loc;
+}
+
 //===----------------------------------------------------------------------===//
 // To layout pass
 //===----------------------------------------------------------------------===//
@@ -297,9 +310,11 @@ class TTIRLayoutDPSOperandsRewriter
               mlir::cast(op.getOperation())
                   .getOperandConstraints()[operand.getOperandNumber()])
               .getValue();
-      auto desiredLayout = createToLayoutOp(
-          rewriter, op.getLoc(), operand.get(), operandConstraint,
-          defaultMemorySpace, defaultDeviceMemoryLayout);
+      Location newLoc =
+          appendInputSuffix(op.getLoc(), operand.getOperandNumber());
+      auto desiredLayout =
+          createToLayoutOp(rewriter, newLoc, operand.get(), operandConstraint,
+                           defaultMemorySpace, defaultDeviceMemoryLayout);
 
       if (desiredLayout) {
         rewriter.modifyOpInPlace(op, [&]() {
@@ -341,9 +356,11 @@ class TTIRLayoutFuncReturnRewriter
       if (isDeviceMemorySpace(initMemorySpace)) {
         initMemoryLayout = defaultDeviceMemoryLayout;
       }
+      Location newLoc =
+          appendInputSuffix(op.getLoc(), operand.getOperandNumber());
       if (auto layout =
-              createToLayoutOp(rewriter, op.getLoc(), operand.get(),
-                               initMemorySpace, initMemoryLayout, tiled);
+              createToLayoutOp(rewriter, newLoc, operand.get(), initMemorySpace,
+                               initMemoryLayout, tiled);
           layout) {
         rewriter.modifyOpInPlace(
             op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); });
diff --git a/runtime/include/tt/runtime/detail/workarounds.h b/runtime/include/tt/runtime/detail/workarounds.h
index a9ad82e98..fcd4e8a65 100644
--- a/runtime/include/tt/runtime/detail/workarounds.h
+++ b/runtime/include/tt/runtime/detail/workarounds.h
@@ -16,12 +16,13 @@ struct Env {
   constexpr static Env
 #endif
   get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true,
-      bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true)
+      bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true,
+      bool setMatmul1DProgramConfig = true)
 #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
       ;
 #else
   {
-    return Env(true, true, true, true);
+    return Env(true, true, true, true, true);
   }
 #endif
   // TODO(bug #272), determine correct layout by tile shape in the future
@@ -41,21 +42,32 @@ struct Env {
   // instead of adding a method in runtime
   bool maxpool2dPreshard;
 
+  // TODO(bug #891): ttnn::matmul doesn't choose the correct program config.
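+  // When enabled, the runtime supplies a 1D multicast program config for
+  // width-sharded matmul outputs instead of relying on ttnn's own selection.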
+  bool setMatmul1DProgramConfig;
+
 private:
   constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor,
-                bool fullOpForceRowMajor, bool maxpool2dPreshard)
+                bool fullOpForceRowMajor, bool maxpool2dPreshard,
+                bool setMatmul1DProgramConfig)
       : ignoreTileShape(ignoreTileShape),
         emptyOpForceRowMajor(emptyOpForceRowMajor),
         fullOpForceRowMajor(fullOpForceRowMajor),
-        maxpool2dPreshard(maxpool2dPreshard) {}
+        maxpool2dPreshard(maxpool2dPreshard),
+        setMatmul1DProgramConfig(setMatmul1DProgramConfig) {}
 };
 
 inline std::ostream &operator<<(std::ostream &os, const Env &env) {
   os << "workaround::Env{\n";
-  os << "\t" << "ignoreTileShape: " << env.ignoreTileShape << ",\n";
-  os << "\t" << "emptyOpForceRowMajor: " << env.emptyOpForceRowMajor << ",\n";
-  os << "\t" << "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n";
-  os << "\t" << "maxpool2dPreshard: " << env.maxpool2dPreshard << "\n";
+  os << "\t"
+     << "ignoreTileShape: " << env.ignoreTileShape << ",\n";
+  os << "\t"
+     << "emptyOpForceRowMajor: " << env.emptyOpForceRowMajor << ",\n";
+  os << "\t"
+     << "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n";
+  os << "\t"
+     << "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
+  os << "\t"
+     << "setMatmul1DProgramConfig: " << env.setMatmul1DProgramConfig << "\n";
   os << "}";
   return os;
 }
diff --git a/runtime/lib/common/workarounds.cpp b/runtime/lib/common/workarounds.cpp
index 8d3535dab..f5ddea03a 100644
--- a/runtime/lib/common/workarounds.cpp
+++ b/runtime/lib/common/workarounds.cpp
@@ -7,9 +7,11 @@
 namespace tt::runtime::workaround {
 #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
 const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor,
-                    bool fullOpForceRowMajor, bool maxpool2dPreshard) {
+                    bool fullOpForceRowMajor, bool maxpool2dPreshard,
+                    bool setMatmul1DProgramConfig) {
   static const Env config(ignoreTileShape, emptyOpForceRowMajor,
-                          fullOpForceRowMajor, maxpool2dPreshard);
+                          fullOpForceRowMajor, maxpool2dPreshard,
+                          setMatmul1DProgramConfig);
   return config;
 }
 #endif
diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp
index cf4ce5f7b..2e3104942 100644
--- a/runtime/lib/ttnn/operations/matmul/matmul.cpp
+++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp
@@ -5,10 +5,81 @@
 #include "matmul.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/detail/workarounds.h"
 #include "tt/runtime/ttnn/operations/utils.h"
+#include <optional>
 
 // ANCHOR: adding_an_op_matmul_runtime_operations
 namespace tt::runtime::ttnn::operations::matmul {
+
+// This is a workaround for the lack of program config selection in ttnn.matmul.
+// The logic here is temporary and totally incomplete.
+::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig
+createProgramConfig(const ::tt::target::ttnn::MatmulOp *op,
+                    ProgramContext &context,
+                    ::tt::tt_metal::MemoryConfig outputMemoryConfig) {
+
+  uint32_t numCores = outputMemoryConfig.shard_spec->grid.num_cores();
+  bool fuseBatch = true; // required for sharded inputs
+
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
+  const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id());
+
+  // Note: ttnn::Shape::value returns a legacy tt::tt_metal::Shape object,
+  // which does take padding into account.
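+  // The loop below computes the padded volume of the LHS; with fuseBatch set,
+  // every dimension except the innermost is folded into M.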
+  uint32_t volume = 1;
+  for (size_t i = 0; i < lhs.shape().rank(); i++) {
+    volume *= lhs.shape().value[i];
+  }
+
+  uint32_t M =
+      fuseBatch ? volume / lhs.shape().value[-1] : lhs.shape().value[-2];
+  // uint32_t K = lhs.shape().value[-1];
+  uint32_t N = rhs.shape().value[-1];
+  bool mcastIn0 = N >= M;
+
+  uint32_t perCoreM, perCoreN;
+
+  if (mcastIn0) {
+    perCoreM = M / tt::constants::TILE_HEIGHT;
+    perCoreN = tt::div_up(tt::div_up(N, numCores), tt::constants::TILE_WIDTH);
+  } else {
+    perCoreM = tt::div_up(tt::div_up(M, numCores), tt::constants::TILE_HEIGHT);
+    perCoreN = N / tt::constants::TILE_WIDTH;
+  }
+
+  // uint32_t in0_block_w = (K / tt::constants::TILE_WIDTH) % 2 == 0 ? 2 : 1;
+  uint32_t in0BlockW = 1;
+
+  // These should work in most cases, but there is logic we can use to
+  // optimize this later.
+  uint32_t outSubblockH = 1, outSubblockW = 1;
+
+  assert(outputMemoryConfig.shard_spec->grid.ranges().size() == 1);
+  CoreCoord computeWithStorageGridSize =
+      outputMemoryConfig.shard_spec->grid.ranges().begin()->grid_size();
+  if (lhs.is_sharded()) {
+    CoreCoord lhs_grid_size =
+        lhs.shard_spec()->grid.ranges().begin()->grid_size();
+    if (computeWithStorageGridSize < lhs_grid_size) {
+      computeWithStorageGridSize = lhs_grid_size;
+    }
+  }
+
+  return ::ttnn::operations::matmul::
+      MatmulMultiCoreReuseMultiCast1DProgramConfig{
+          .compute_with_storage_grid_size = computeWithStorageGridSize,
+          .in0_block_w = in0BlockW,
+          .out_subblock_h = outSubblockH,
+          .out_subblock_w = outSubblockW,
+          .per_core_M = perCoreM,
+          .per_core_N = perCoreN,
+          .fuse_batch = true,
+          .fused_activation = std::nullopt,
+          .mcast_in0 = mcastIn0};
+}
+
 void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id());
@@ -18,9 +89,21 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) {
   ::ttnn::DataType outputDataType = utils::getDataType(op->out());
   ::tt::tt_metal::MemoryConfig outputMemoryConfig =
       utils::createMemoryConfig(op->out());
+
+  std::optional<
+      ::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig>
+      programConfig = std::nullopt;
+
+  // TODO(bug #891): ttnn::matmul doesn't choose the correct program config.
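+  // The workaround only applies to width-sharded outputs; in every other case
+  // programConfig stays std::nullopt and ttnn picks the config itself.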
+  if (workaround::Env::get().setMatmul1DProgramConfig &&
+      outputMemoryConfig.memory_layout ==
+          ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) {
+    programConfig = createProgramConfig(op, context, outputMemoryConfig);
+  }
+
   ::ttnn::Tensor out = ::ttnn::operations::matmul::matmul(
       lhs, rhs, /*bias=*/std::nullopt,
-      ::ttnn::operations::matmul::Matmul{/*program_config=*/std::nullopt,
+      ::ttnn::operations::matmul::Matmul{/*program_config=*/programConfig,
                                          /*bcast_batch=*/std::nullopt,
                                          outputMemoryConfig, outputDataType});
   tensorPool.insert_or_assign(op->out()->global_id(), out);
diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py
index 3822b5f88..7cdab749c 100644
--- a/runtime/tools/python/ttrt/common/run.py
+++ b/runtime/tools/python/ttrt/common/run.py
@@ -151,6 +151,13 @@ def initialize_api():
            choices=[True, False],
            help="Disable maxpool2d preshard workaround",
        )
+        Run.register_arg(
+            name="--disable-matmul-1d-program-config",
+            type=bool,
+            default=False,
+            choices=[True, False],
+            help="Disable matmul 1d program config workaround",
+        )
        Run.register_arg(
            name="binary",
            type=str,
@@ -331,6 +338,7 @@ def _execute(binaries):
                not self["--disable-empty-op-row-major"],
                not self["--disable-full-op-row-major"],
                not self["--disable-maxpool2d-preshard"],
+                not self["--disable-matmul-1d-program-config"],
            )
            self.logging.debug(f"setting tt runtime workaround env={workaround_env}")
            self.logging.debug(f"setting torch manual seed={self['--seed']}")
diff --git a/test/ttmlir/Dialect/TTNN/input_layout_loc_override.mlir b/test/ttmlir/Dialect/TTNN/input_layout_loc_override.mlir
new file mode 100644
index 000000000..be0413df8
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/input_layout_loc_override.mlir
@@ -0,0 +1,21 @@
+// RUN: ttmlir-opt --mlir-print-debuginfo --ttir-to-ttnn-backend-pipeline="enable-optimizer=true sharding-pass-enabled=true override-output-layout=matmul_1_in_1_layout=1x1:l1:interleaved" %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#loc = loc("Matmul":4294967295:0)
+// CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = loc("matmul_1_in_0_layout"(#loc3))
+// CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3))
+// CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3))
+// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<128x96xbf16, #l1_>, interleaved>
+
+module attributes {} {
+  func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
+    %0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2)
+    // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]])
+    // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
+    // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]])
+    %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2)
+    return %1 : tensor<64x96xbf16>
+  } loc(#loc)
+} loc(#loc)
+
+#loc1 = loc("Matmul":4294967295:1)
+#loc2 = loc("matmul_1"(#loc1))
diff --git a/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir b/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
new file mode 100644
index 000000000..ca345dc01
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
@@ -0,0 +1,42 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true sharding-pass-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,add_2_in_1_layout=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,add_6_in_1_layout=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#loc = loc("MNISTLinear":4294967295:0)
+module @"tt-forge-graph" attributes {} {
+  func.func @main(%arg0: tensor<32x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<32x32xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x32xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<32x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<32x32xf32> {
+    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<32x32xf32, #l1_>, width_sharded>
+    // CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
+    %0 = tensor.empty() : tensor<32x256xf32> loc(#loc8)
+    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
+    %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x784xf32>, tensor<784x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc8)
+    %2 = tensor.empty() : tensor<32x256xf32> loc(#loc9)
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
+    %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x256xf32>, tensor<32x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc9)
+    %4 = tensor.empty() : tensor<32x256xf32> loc(#loc10)
+    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
+    %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<32x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc10)
+    %6 = tensor.empty() : tensor<32x32xf32> loc(#loc11)
+    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<32x32xf32, #[[LAYOUT_2]]>
+    %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x256xf32>, tensor<256x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc11)
+    %8 = tensor.empty() : tensor<32x32xf32> loc(#loc12)
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xf32, #[[LAYOUT_2]]>
+    %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc12)
+    %10 = tensor.empty() : tensor<32x32xf32> loc(#loc13)
+    %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc13)
+    return %11 : tensor<32x32xf32> loc(#loc7)
+  } loc(#loc)
+} loc(#loc)
+#loc1 = loc("MNISTLinear":4294967295:10)
+#loc2 = loc("MNISTLinear":4294967295:8)
+#loc3 = loc("MNISTLinear":4294967295:6)
+#loc4 = loc("MNISTLinear":4294967295:4)
+#loc5 = loc("MNISTLinear":4294967295:3)
+#loc6 = loc("MNISTLinear":4294967295:2)
+#loc7 = loc(unknown)
+#loc8 = loc("matmul_1"(#loc1))
+#loc9 = loc("add_2"(#loc2))
+#loc10 = loc("relu_3"(#loc3))
+#loc11 = loc("matmul_5"(#loc4))
+#loc12 = loc("add_6"(#loc5))
+#loc13 = loc("softmax_7"(#loc6))