From 8220beef4fff756a0e8c4a61bec5448ba07c218e Mon Sep 17 00:00:00 2001 From: Stefan Djordjevic Date: Wed, 20 Nov 2024 15:42:07 +0000 Subject: [PATCH] Adding workaround op interface --- include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt | 1 + include/ttmlir/Dialect/TTNN/IR/TTNNBase.td | 3 +- include/ttmlir/Dialect/TTNN/IR/TTNNOps.h | 1 + include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 4 +- .../Dialect/TTNN/IR/TTNNWorkaroundInterface.h | 22 + .../TTNN/IR/TTNNWorkaroundInterface.td | 46 ++ .../ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h | 340 ++++++++++++++ .../Dialect/TTNN/Pipelines/TTNNPipelines.h | 5 + .../ttmlir/Dialect/TTNN/Transforms/Passes.td | 8 + .../Dialect/TTNN/Utils/TransformUtils.h | 17 + lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 37 +- lib/Dialect/TTNN/IR/CMakeLists.txt | 2 + .../TTNN/IR/TTNNWorkaroundInterface.cpp | 139 ++++++ lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 9 + lib/Dialect/TTNN/Transforms/CMakeLists.txt | 5 +- .../TTNN/Transforms/TTNNWorkarounds.cpp | 434 ++++++++++++++++++ lib/Dialect/TTNN/Utils/CMakeLists.txt | 4 +- lib/Dialect/TTNN/Utils/TransformUtils.cpp | 30 ++ .../Workarounds/simple_workaround.mlir | 31 ++ 19 files changed, 1104 insertions(+), 34 deletions(-) create mode 100644 include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h create mode 100644 include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td create mode 100644 include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h create mode 100644 include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h create mode 100644 lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp create mode 100644 lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp create mode 100644 lib/Dialect/TTNN/Utils/TransformUtils.cpp create mode 100644 test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir diff --git a/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt b/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt index cfd65fe8db..fbf68f69dd 100644 --- a/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt +++ b/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc) add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc) add_mlir_interface(TTNNOpModelInterface) +add_mlir_interface(TTNNWorkaroundInterface) set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td) mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td b/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td index b1821c8f1b..34d3daf9cc 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td @@ -7,6 +7,7 @@ include "mlir/IR/OpBase.td" include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td" +include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td" //===----------------------------------------------------------------------===// // TTNN dialect definition. 
@@ -44,6 +45,6 @@ def TTNN_Dialect : Dialect { //===----------------------------------------------------------------------===// class TTNN_Op traits = []> : - Op; + Op; #endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h index e66fab65a3..457c7722bb 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h @@ -18,6 +18,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc" +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h" #define GET_OP_CLASSES #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h.inc" diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 910ed7dfd9..766583f413 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -325,7 +325,7 @@ def TTNN_Expm1Op: TTNN_ElementwiseUnaryOp<"expm1"> { }]; } -class TTIR_ElementwiseUnaryWithFloatParameterOp traits = []> : +class TTNN_ElementwiseUnaryWithFloatParameterOp traits = []> : TTNN_ElementwiseUnaryOp { let summary = "Eltwise unary op with the float parameter."; let description = [{ @@ -345,7 +345,7 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp tra ]; } -def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { +def TTNN_LeakyReluOp : TTNN_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { let summary = "Eltwise leaky relu operation."; let description = [{ The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h new file mode 100644 index 0000000000..f2ba0d17f3 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_H +#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_H + +#include "mlir/IR/Operation.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" + +namespace mlir::tt::ttnn::wa { +// Gets default operand workarounds for the given operation. This method is +// called from the TTNNWorkaroundInterface and its getOperandsWorkarounds +// method. +TTNNOperandsWorkarounds getDefaultOperandWorkarounds(Operation *op); + +// Verifies the TTNNWorkaroundInterface +mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op); +} // namespace mlir::tt::ttnn::wa + +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h.inc" + +#endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td new file mode 100644 index 0000000000..777f6182af --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD +#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD + +include "mlir/IR/OpBase.td" + +// This interface is used to specify workarounds for TTNN operations. +def TTNN_WorkaroundInterface : OpInterface<"TTNNWorkaroundInterface"> { + let cppNamespace = "::mlir::tt::ttnn::wa"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the workarounds associated with each operand and result of this operation. 
+ If the operation is a Destination-Passing Style (DPS) operation, the same workarounds + must apply to both the DPS initial operands and the operation results. These constraints + are verified through the interface verifier. + + For example, consider the following ttnn operations: + %0 = "ttnn.emptyOp"() : () -> tensor<1x1xf32> + %1 = "ttnn.abs"(%arg0, %0) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> + + In this example: + - The ttnn.abs operation has two input operand workarounds. + - It has one output operand workaround. + - The output workaround must match the workaround for the second input operand, + ensuring consistency as required by the DPS pattern. + }], + /*retTy=*/"TTNNOperandsWorkarounds", + /*methodName=*/"getOperandsWorkarounds", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getDefaultOperandWorkarounds(this->getOperation()); + }] + >, + ]; + + let verify = [{ + return verifyTTNNWorkaroundInterface($_op); + }]; +} + +#endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h new file mode 100644 index 0000000000..ff3d68c049 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h @@ -0,0 +1,340 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_IR_TTNN_LAYOUT_OPERAND_WORKAROUNDS_H +#define TTMLIR_DIALECT_TTNN_IR_TTNN_LAYOUT_OPERAND_WORKAROUNDS_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include +#include + +namespace mlir::tt::ttnn::wa { + +// Class that encapsulates tensor layout workaround. +// Possible workaround values are: +// - Tile +// - RowMajor +class TTNNTensorLayoutWorkaround { +public: + // Default constructor. It wrapes Layout enum inside shared_ptr. + // Its null value represents no workaround. + TTNNTensorLayoutWorkaround() { this->layoutWorkaround = nullptr; } + + // Constructor that takes Layout enum and wraps it inside shared_ptr. + TTNNTensorLayoutWorkaround(Layout layoutWorkaround) { + this->layoutWorkaround = std::make_shared(layoutWorkaround); + } + + // Returns shared_ptr to Layout enum. + std::shared_ptr getTensorLayoutWorkaround() const { + return layoutWorkaround; + } + + // Equality operator. + bool operator==(const TTNNTensorLayoutWorkaround &rhs) const { + if (this->layoutWorkaround && rhs.layoutWorkaround) { + return *this->layoutWorkaround == *rhs.layoutWorkaround; + } + + return !this->layoutWorkaround && !rhs.layoutWorkaround; + } + + // Inequality operator. + bool operator!=(const TTNNTensorLayoutWorkaround &rhs) const { + return !(*this == rhs); + } + +private: + // Shared pointer to Layout enum. + std::shared_ptr layoutWorkaround; +}; + +// Class that encapsulates tensor buffer type workaround. +// Possible workaround values are: +// - SystemMemory +// - DRAM +// - L1 +class TTTNNTensorBufferTypeWorkaround { +public: + // Default constructor. It wraps BufferType enum inside shared_ptr. + // Its null value represents no workaround. + TTTNNTensorBufferTypeWorkaround() { + this->tensorBufferTypeWorkaround = nullptr; + } + + // Constructor that takes BufferType enum and wraps it inside shared_ptr. + TTTNNTensorBufferTypeWorkaround(BufferType tensorBufferType) { + this->tensorBufferTypeWorkaround = + std::make_shared(tensorBufferType); + } + + // Returns shared_ptr to BufferType enum. + std::shared_ptr getTensorBufferTypeWorkaround() const { + return tensorBufferTypeWorkaround; + } + + // Equality operator. 
+ bool operator==(const TTTNNTensorBufferTypeWorkaround &rhs) const { + if (this->tensorBufferTypeWorkaround && rhs.tensorBufferTypeWorkaround) { + return *this->tensorBufferTypeWorkaround == + *rhs.tensorBufferTypeWorkaround; + } + + return !this->tensorBufferTypeWorkaround && !rhs.tensorBufferTypeWorkaround; + } + + // Inequality operator. + bool operator!=(const TTTNNTensorBufferTypeWorkaround &rhs) const { + return !(*this == rhs); + } + +private: + // Shared pointer to BufferType enum. + std::shared_ptr tensorBufferTypeWorkaround; +}; + +// Class that encapsulates tensor memory layout workaround. +// Possible workaround values are: +// - Interleaved +// - SingleBank +// - HeightSharded +// - WidthSharded +// - BlockSharded +class TTNNTensorMemoryLayoutWorkaround { +public: + // Default constructor. It wrapes TensorMemoryLayout enum inside shared_ptr. + // Its null value represents no workaround. + TTNNTensorMemoryLayoutWorkaround() { + this->tensorMemoryLayoutWorkaround = nullptr; + } + + // Constructor that takes TensorMemoryLayout enum and wraps it inside + // shared_ptr. + TTNNTensorMemoryLayoutWorkaround(TensorMemoryLayout memoryLayoutWorkaround) { + this->tensorMemoryLayoutWorkaround = + std::make_shared(memoryLayoutWorkaround); + } + + // Returns shared_ptr to TensorMemoryLayout enum. + std::shared_ptr getTensorMemoryLayoutWorkaround() const { + return tensorMemoryLayoutWorkaround; + } + + // Equality operator. + bool operator==(const TTNNTensorMemoryLayoutWorkaround &rhs) const { + if (this->tensorMemoryLayoutWorkaround && + rhs.tensorMemoryLayoutWorkaround) { + return *this->tensorMemoryLayoutWorkaround == + *rhs.tensorMemoryLayoutWorkaround; + } + + return !this->tensorMemoryLayoutWorkaround && + !rhs.tensorMemoryLayoutWorkaround; + } + + // Inequality operator. + bool operator!=(const TTNNTensorMemoryLayoutWorkaround &rhs) const { + return !(*this == rhs); + } + +private: + // Shared pointer to TensorMemoryLayout enum. + std::shared_ptr tensorMemoryLayoutWorkaround; +}; + +// Class that encapsulates operand workarounds. +// It contains tensor layout, tensor buffer type and tensor memory layout +// workarounds. +class TTNNOperandWorkarounds { +public: + // Default constructor with no workarounds. + TTNNOperandWorkarounds() {} + + // Constructor that takes tensor layout, tensor buffer type and tensor memory + // layout workarounds. + TTNNOperandWorkarounds( + TTNNTensorLayoutWorkaround tensorLayoutWorkaround, + TTTNNTensorBufferTypeWorkaround tensorBufferTypeWorkaround, + TTNNTensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround) { + this->tensorLayoutWorkaround = tensorLayoutWorkaround; + this->tensorBufferTypeWorkaround = tensorBufferTypeWorkaround; + this->tensorMemoryLayoutWorkaround = tensorMemoryLayoutWorkaround; + } + + // Returns tensor layout workaround. + TTNNTensorLayoutWorkaround getTensorLayoutWorkaround() { + return tensorLayoutWorkaround; + } + + // Returns tensor buffer type workaround. + TTTNNTensorBufferTypeWorkaround getTensorBufferTypeWorkaround() { + return tensorBufferTypeWorkaround; + } + + // Returns tensor memory layout workaround. + TTNNTensorMemoryLayoutWorkaround getTensorMemoryLayoutWorkaround() { + return tensorMemoryLayoutWorkaround; + } + + // Equality operator. 
+ bool operator==(const TTNNOperandWorkarounds &rhs) const { + return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround && + tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround && + tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround; + } + + // Inequality operator. + bool operator!=(const TTNNOperandWorkarounds &rhs) const { + return !(*this == rhs); + } + + // Returns true if any of the workarounds is set. + bool hasAnyWorkaround() { + return tensorLayoutWorkaround.getTensorLayoutWorkaround() || + tensorBufferTypeWorkaround.getTensorBufferTypeWorkaround() || + tensorMemoryLayoutWorkaround.getTensorMemoryLayoutWorkaround(); + } + +private: + // Tensor layout workaround. + TTNNTensorLayoutWorkaround tensorLayoutWorkaround; + // Tensor buffer type workaround. + TTTNNTensorBufferTypeWorkaround tensorBufferTypeWorkaround; + // Tensor memory layout workaround. + TTNNTensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround; +}; + +// Class that encapsulates operands workarounds. +// It contains input and output workarounds for operands. +class TTNNOperandsWorkarounds { +public: + // Default constructor with no workarounds. + TTNNOperandsWorkarounds() {} + + // Constructor that takes input and output workarounds for operands. + TTNNOperandsWorkarounds( + std::vector inputOperandWorkarounds, + std::vector outputOperandWorkarounds) + : inputOperandWorkarounds(inputOperandWorkarounds), + outputOperandWorkarounds(outputOperandWorkarounds) {} + + // Returns input operand workarounds. + std::vector getInputOperandWorkarounds() const { + return inputOperandWorkarounds; + } + + // Returns output operand workarounds. + std::vector getOutputOperandWorkarounds() const { + return outputOperandWorkarounds; + } + + // Adds input operand workaround. + TTNNOperandsWorkarounds & + addInputOperandWorkaround(TTNNOperandWorkarounds inputOperandWorkaround) { + inputOperandWorkarounds.emplace_back(inputOperandWorkaround); + return *this; + } + + // Adds output operand workaround. + TTNNOperandsWorkarounds & + addOutputOperandWorkaround(TTNNOperandWorkarounds outputOperandWorkaround) { + outputOperandWorkarounds.emplace_back(outputOperandWorkaround); + return *this; + } + +private: + // Workarounds for input operands. + std::vector inputOperandWorkarounds; + // Workarounds for output operands. + std::vector outputOperandWorkarounds; +}; + +// Class that provides factory methods for creating workarounds. 
+class WorkaroundFactory { +public: + // Tensor layout factory methods + static TTNNTensorLayoutWorkaround createTileTTNNTensorLayoutWorkaround() { + return TTNNTensorLayoutWorkaround(Layout::Tile); + } + + static TTNNTensorLayoutWorkaround createRowMajorTTNNTensorLayoutWorkaround() { + return TTNNTensorLayoutWorkaround(Layout::RowMajor); + } + + static TTNNTensorLayoutWorkaround createDefaultTTNNTensorLayoutWorkaround() { + return TTNNTensorLayoutWorkaround(); + } + + // Tensor buffer type factory methods + static TTTNNTensorBufferTypeWorkaround + createSystemMemoryTTTNNTensorBufferTypeWorkaround() { + return TTTNNTensorBufferTypeWorkaround(BufferType::SystemMemory); + } + + static TTTNNTensorBufferTypeWorkaround + createDeviceDRAMTTTNNTensorBufferTypeWorkaround() { + return TTTNNTensorBufferTypeWorkaround(BufferType::DRAM); + } + + static TTTNNTensorBufferTypeWorkaround + createDeviceL1TTTNNTensorBufferTypeWorkaround() { + return TTTNNTensorBufferTypeWorkaround(BufferType::L1); + } + + static TTTNNTensorBufferTypeWorkaround + createDefaultTTTNNTensorBufferTypeWorkaround() { + return TTTNNTensorBufferTypeWorkaround(); + } + + // Tensor memory layout factory methods + static TTNNTensorMemoryLayoutWorkaround + createInterleavedTTNNTensorMemoryLayoutWorkaround() { + return TTNNTensorMemoryLayoutWorkaround(TensorMemoryLayout::Interleaved); + } + + static TTNNTensorMemoryLayoutWorkaround + createSingleBankTTNNTensorMemoryLayoutWorkaround() { + return TTNNTensorMemoryLayoutWorkaround(TensorMemoryLayout::SingleBank); + } + + static TTNNTensorMemoryLayoutWorkaround + createHeightShardedTTNNTensorMemoryLayoutWorkaround() { + return TTNNTensorMemoryLayoutWorkaround(TensorMemoryLayout::HeightSharded); + } + + static TTNNTensorMemoryLayoutWorkaround + createWidthShardedTTNNTensorMemoryLayoutWorkaround() { + return TTNNTensorMemoryLayoutWorkaround(TensorMemoryLayout::WidthSharded); + } + + static TTNNTensorMemoryLayoutWorkaround + createBlockShardedTTNNTensorMemoryLayoutWorkaround() { + return TTNNTensorMemoryLayoutWorkaround(TensorMemoryLayout::BlockSharded); + } + + static TTNNTensorMemoryLayoutWorkaround + createDefaultTTNNTensorMemoryLayoutWorkaround() { + return TTNNTensorMemoryLayoutWorkaround(); + } + + // Operand workarounds factory methods + static TTNNOperandWorkarounds createDefaultTTNNOperandWorkarounds() { + return TTNNOperandWorkarounds(); + } + + // Operands workarounds factory methods + static TTNNOperandsWorkarounds + createDefaultTTNNOperandsWorkarounds(int inputSize, int outputSize) { + std::vector inputOperandWorkarounds( + inputSize, createDefaultTTNNOperandWorkarounds()); + std::vector outputOperandWorkarounds( + outputSize, createDefaultTTNNOperandWorkarounds()); + return TTNNOperandsWorkarounds(inputOperandWorkarounds, + outputOperandWorkarounds); + } +}; + +} // namespace mlir::tt::ttnn::wa + +#endif diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 48c723e1cd..7302c57aa4 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -111,6 +111,11 @@ struct TTIRToTTNNBackendPipelineOptions ListOption meshShape{ *this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")}; + + // Option to enable/disable the workaround pass + Option workaroundPassEnabled{*this, "enable-workaround-pass", + llvm::cl::desc("Enable workaround pass."), + llvm::cl::init(false)}; }; // TTIR to EmitC pipeline options. 
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td index c29d01f7e4..8aa0880854 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td @@ -28,4 +28,12 @@ def TTNNLayout : Pass<"ttnn-layout", "::mlir::ModuleOp"> { }]; } +def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> { + let summary = "Apply workarounds to the IR."; + let description = [{ + This pass applies necessary TTNN workarounds to the IR in order to create + a valid TTNN flatbuffer. + }]; +} + #endif diff --git a/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h new file mode 100644 index 0000000000..4b9cd7bfd9 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_TRANSFORM_UTILS_H +#define TTMLIR_DIALECT_TTNN_UTILS_TRANSFORM_UTILS_H + +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" + +namespace mlir::tt::ttnn::utils { +// Get or insert device for the given operation. +mlir::Value getOrInsertDevice(mlir::PatternRewriter &rewriter, + mlir::Operation *op); +} // namespace mlir::tt::ttnn::utils + +#endif diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 12e29a9609..ccb614d596 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -9,6 +9,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Types/Types.h" +#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -28,27 +29,6 @@ using namespace mlir::tt; namespace { -// Gets or inserts a GetDeviceOp at the top of the current block of the given -// operation. -static Value getOrInsertDevice(ConversionPatternRewriter &rewriter, - Operation *op) { - Block *block = op->getBlock(); - for (auto &op : block->getOperations()) { - if (auto deviceOp = dyn_cast(op)) { - return deviceOp.getResult(); - } - } - - DeviceAttr deviceAttr = getCurrentScopeDevice(op); - auto currentInsertionPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(block, block->begin()); - auto deviceOp = rewriter.create( - op->getLoc(), rewriter.getType(deviceAttr), - ttnn::MeshShapeAttr::get(op->getContext(), 1, 1)); - rewriter.restoreInsertionPoint(currentInsertionPoint); - return deviceOp.getResult(); -} - class TensorEmptyConversionPattern : public OpConversionPattern { public: @@ -100,7 +80,7 @@ class TensorEmptyConversionPattern // Create MemoryConfigAttr // - auto device = getOrInsertDevice(rewriter, op); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get( op.getContext(), ttnn::TensorMemoryLayoutAttr::get(op.getContext(), memLayout), @@ -199,7 +179,8 @@ class ToLayoutOpConversionPattern rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(result), adaptor.getInput(), outputLayout, outputDataType, outputMemConfigAttr, - isOutputOnHost ? nullptr : getOrInsertDevice(rewriter, op)); + isOutputOnHost ? 
nullptr + : ::ttnn::utils::getOrInsertDevice(rewriter, op)); return success(); } @@ -525,7 +506,7 @@ class ConstantOpConversionPattern } if (valueAttr.isSplat()) { - Value device = getOrInsertDevice(rewriter, op); + Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); float fillValue = valueAttr.getElementType().isInteger() ? getIntegerValue(valueAttr) @@ -627,7 +608,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { matchAndRewrite(ttir::Conv2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto device = getOrInsertDevice(rewriter, op); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); auto kernel_ty = mlir::cast(adaptor.getWeight().getType()); llvm::ArrayRef kernel_shape = kernel_ty.getShape(); @@ -725,7 +706,7 @@ class MaxPool2dOpConversionPattern "TTNN max_pool2d does not support padding top/bottom/left/right " "separately"); - auto device = getOrInsertDevice(rewriter, op); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); auto input_ty = mlir::cast(adaptor.getInput().getType()); llvm::ArrayRef input_shape = input_ty.getShape(); @@ -871,7 +852,7 @@ class SubtractOpConversionPattern // addOp(lhs, negOp(rhs)) } else { - Value device = getOrInsertDevice(rewriter, srcOp); + Value device = ::ttnn::utils::getOrInsertDevice(rewriter, srcOp); tensor::EmptyOp negEmptyOp = rewriter.create( srcOp.getLoc(), this->getTypeConverter()->convertType(rhsType), device); @@ -897,7 +878,7 @@ class AllGatherOpConversionPattern ConversionPatternRewriter &rewriter) const override { RankedTensorType type = mlir::cast(adaptor.getInput().getType()); - Value device = getOrInsertDevice(rewriter, op); + Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); tensor::EmptyOp emptyOp = rewriter.create( op.getLoc(), this->getTypeConverter()->convertType(type), device); diff --git a/lib/Dialect/TTNN/IR/CMakeLists.txt b/lib/Dialect/TTNN/IR/CMakeLists.txt index 1620e96b5c..54134a4862 100644 --- a/lib/Dialect/TTNN/IR/CMakeLists.txt +++ b/lib/Dialect/TTNN/IR/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTTNNDialect TTNNOps.cpp TTNNOpModelInterface.cpp TTNNOpsTypes.cpp + TTNNWorkaroundInterface.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir @@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRTTNNDialect DEPENDS MLIRTTNNOpsIncGen MLIRTTOpsIncGen + MLIRTTNNWorkaroundInterfaceIncGen LINK_LIBS PUBLIC TTMLIRTTNNUtils diff --git a/lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp b/lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp new file mode 100644 index 0000000000..f6717a5bf9 --- /dev/null +++ b/lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp @@ -0,0 +1,139 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" +#include + +namespace mlir::tt::ttnn::wa { +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp.inc" + +// Verifier function for TTNN Workaround Interface +mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op) { + TTNNWorkaroundInterface workaroundOp = + mlir::cast(op); + + // Verify that the number of input and output operand workarounds is the same + // as the number of tensor operands and tensor results + size_t tensorInputs = 0; + size_t tensorResults = 0; + + // Count the number of tensor input operands including DPS inits + for (auto operand : op->getOperands()) { + 
if (mlir::isa<::mlir::RankedTensorType>(operand.getType())) { + tensorInputs++; + } + } + + // Count the number of tensor results + for (auto result : op->getResults()) { + if (mlir::isa<::mlir::RankedTensorType>(result.getType())) { + tensorResults++; + } + } + + TTNNOperandsWorkarounds workarounds = workaroundOp.getOperandsWorkarounds(); + + if (workarounds.getInputOperandWorkarounds().size() != tensorInputs) { + return op->emitOpError("Number of input operand workarounds does not match " + "the number of tensor inputs"); + } + + if (workarounds.getOutputOperandWorkarounds().size() != tensorResults) { + return op->emitOpError("Number of output operand workarounds does not " + "match the number of tensor results"); + } + + // For DPS ops, verify that the output workaround is the same as the input + // init workaround + if (mlir::isa(op)) { + DestinationStyleOpInterface dpsOp = + mlir::dyn_cast(op); + + // Go through all the operands and for each DPS init operand, check if the + // output workaround is the same + int resultIndex = 0; + for (size_t i = 0; i < op->getNumOperands(); i++) { + OpOperand &operand = op->getOpOperand(i); + + // Check only RankedTensorType operands + if (mlir::isa<::mlir::RankedTensorType>(operand.get().getType()) && + dpsOp.isDpsInit(&operand)) { + if (workarounds.getOutputOperandWorkarounds()[resultIndex] != + workarounds.getInputOperandWorkarounds()[i]) { + return op->emitOpError() << "DPS output workaround does not match " + "the input DPS init operand workaround " + << i << " and " << resultIndex; + } + resultIndex++; + } + } + } + + // All checks passed, return success + return mlir::success(); +} + +// Operand workarounds are defined for each operand and result of the operation. +// If the operation is a DPS operation, same workarounds must be applied for the +// DPS inits and DPS op outputs. All of this is verified in the interface +// verifier. For example, if we have a following ttnn operations: +// +// %0 = "ttnn.emptyOp"() : () -> tensor<1x1xf32> +// %1 = "ttnn.abs"(%arg0, %0) : (tensor<1x1xf32>, tensor<1x1xf32>) -> +// tensor<1x1xf32> +// +// In this example, we will have 2 input operand workarounds and 1 output +// operand workaround, hence the output workaround must be the same as for the +// second input operand. 
+TTNNOperandsWorkarounds getDefaultOperandWorkarounds(Operation *op) { + // Special case empty op + // Empty op currently only supports creation in row major layout + if (mlir::isa<::mlir::tt::ttnn::EmptyOp>(op)) { + return WorkaroundFactory::createDefaultTTNNOperandsWorkarounds(0, 0) + .addOutputOperandWorkaround(TTNNOperandWorkarounds( + WorkaroundFactory::createRowMajorTTNNTensorLayoutWorkaround(), + WorkaroundFactory::createDefaultTTTNNTensorBufferTypeWorkaround(), + WorkaroundFactory:: + createDefaultTTNNTensorMemoryLayoutWorkaround())); + } + + if (mlir::dyn_cast<::mlir::tt::ttnn::AbsOp>(op)) { + return WorkaroundFactory::createDefaultTTNNOperandsWorkarounds(0, 0) + .addInputOperandWorkaround(TTNNOperandWorkarounds( + WorkaroundFactory::createTileTTNNTensorLayoutWorkaround(), + WorkaroundFactory::createDefaultTTTNNTensorBufferTypeWorkaround(), + WorkaroundFactory::createDefaultTTNNTensorMemoryLayoutWorkaround())) + .addInputOperandWorkaround(TTNNOperandWorkarounds( + WorkaroundFactory::createTileTTNNTensorLayoutWorkaround(), + WorkaroundFactory::createDefaultTTTNNTensorBufferTypeWorkaround(), + WorkaroundFactory::createDefaultTTNNTensorMemoryLayoutWorkaround())) + .addOutputOperandWorkaround(TTNNOperandWorkarounds( + WorkaroundFactory::createTileTTNNTensorLayoutWorkaround(), + WorkaroundFactory::createDefaultTTTNNTensorBufferTypeWorkaround(), + WorkaroundFactory:: + createDefaultTTNNTensorMemoryLayoutWorkaround())); + } + + size_t tensorInputs = 0; + size_t tensorResults = 0; + + // Count the number of tensor input operands including DPS inits + for (auto operand : op->getOperands()) { + if (mlir::isa<::mlir::RankedTensorType>(operand.getType())) { + tensorInputs++; + } + } + + // Count the number of tensor results + for (auto result : op->getResults()) { + if (mlir::isa<::mlir::RankedTensorType>(result.getType())) { + tensorResults++; + } + } + + return WorkaroundFactory::createDefaultTTNNOperandsWorkarounds(tensorInputs, + tensorResults); +} +} // namespace mlir::tt::ttnn::wa diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 24980fb7c0..e54377299e 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -62,6 +62,14 @@ void createTTNNPipelineLoweringPasses( pm.addPass(mlir::createRemoveDeadValuesPass()); } +// Create a pass to workaround issues in the TTNN dialect. 
+void createTTNNPipelineWorkaroundPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + if (options.workaroundPassEnabled) { + pm.addPass(createTTNNWorkarounds()); + } +} + void createTTNNPipelineLayoutDecompositionPass( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { pm.addPass(createTTNNDecomposeLayouts()); @@ -111,6 +119,7 @@ void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); createTTNNPipelineLoweringPasses(pm, options); + createTTNNPipelineWorkaroundPass(pm, options); createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLayoutDecompositionPass(pm, options); createTTNNPipelineDeallocPass(pm, options); diff --git a/lib/Dialect/TTNN/Transforms/CMakeLists.txt b/lib/Dialect/TTNN/Transforms/CMakeLists.txt index 3f075148b0..fd21e03d0c 100644 --- a/lib/Dialect/TTNN/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTNN/Transforms/CMakeLists.txt @@ -1,8 +1,9 @@ add_mlir_dialect_library(MLIRTTNNTransforms - TTNNLayout.cpp - Passes.cpp Optimizer.cpp + Passes.cpp + TTNNLayout.cpp TTNNToCpp.cpp + TTNNWorkarounds.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir diff --git a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp new file mode 100644 index 0000000000..3b7f14b87d --- /dev/null +++ b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp @@ -0,0 +1,434 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" +#include "ttmlir/Dialect/TTNN/Types/Types.h" +#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlir::tt::ttnn { +#define GEN_PASS_DEF_TTNNWORKAROUNDS +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" + +// Helper method to apply tensor layout workaround. It accepts workaround and +// current tensor layout as input arguments. It applies the workaround on a +// current tensor layout and returns true if the tensor layout was modified. +static bool applyTensorLayoutWorkaround(wa::TTNNOperandWorkarounds &workaround, + Layout &tensorLayout) { + bool modified = false; + if (workaround.getTensorLayoutWorkaround().getTensorLayoutWorkaround()) { + // Do something with tensor layout workaround + + modified = + tensorLayout != + *workaround.getTensorLayoutWorkaround().getTensorLayoutWorkaround(); + tensorLayout = + *workaround.getTensorLayoutWorkaround().getTensorLayoutWorkaround(); + llvm::outs() << "Layout workaround change: " << modified << "\n"; + } + + return modified; +} + +// Helper method to apply tensor buffer type workaround. It accepts workaround +// and current tensor buffer type as input arguments. It applies the workaround +// on a current tensor buffer type and returns true if the tensor buffer type +// was modified. 
+static bool +applyTensorBufferTypeWorkaround(wa::TTNNOperandWorkarounds &workaround, + BufferType &tensorBufferType) { + bool modified = false; + if (workaround.getTensorBufferTypeWorkaround() + .getTensorBufferTypeWorkaround()) { + // Do something with tensor memory space workaround + llvm::outs() << "Buffer type workaround change\n"; + modified = tensorBufferType != *workaround.getTensorBufferTypeWorkaround() + .getTensorBufferTypeWorkaround(); + tensorBufferType = *workaround.getTensorBufferTypeWorkaround() + .getTensorBufferTypeWorkaround(); + } + + return modified; +} + +// Helper method to apply tensor memory layout workaround. It accepts workaround +// and current tensor memory layout as input arguments. It applies the +// workaround on a current tensor memory layout and returns true if the tensor +// memory layout was modified. +static bool +applyTensorMemoryLayoutWorkaround(wa::TTNNOperandWorkarounds &workaround, + TensorMemoryLayout &tensorMemoryLayout) { + bool modified = false; + if (workaround.getTensorMemoryLayoutWorkaround() + .getTensorMemoryLayoutWorkaround()) { + // Do something with tensor memory layout workaround + llvm::outs() << "Memory layout workaround change\n"; + modified = + tensorMemoryLayout != *workaround.getTensorMemoryLayoutWorkaround() + .getTensorMemoryLayoutWorkaround(); + tensorMemoryLayout = *workaround.getTensorMemoryLayoutWorkaround() + .getTensorMemoryLayoutWorkaround(); + } + + return modified; +} + +// Helper method to propagate output changes into DPS inits. It accepts output +// result as input argument. It iterates over all the uses of the output result +// and updates the DPS init with the new output type. +static void propagateOutputChangesIntoDPSInits(OpResult &outputResult) { + // Iterate over all the uses of the outputResult + for (auto &use : outputResult.getUses()) { + // Get the user operation + Operation *userOp = use.getOwner(); + + // Check if the user operation is a DPS op + if (mlir::isa(userOp)) { + DestinationStyleOpInterface dpsOp = + mlir::dyn_cast(userOp); + + // Check if the use is a DPS init + if (dpsOp && dpsOp.isDpsInit(&use)) { + // Update the DPS init with the new output type + OperandRange dpsInits = dpsOp.getDpsInits(); + dpsOp + ->getResult(dpsInits.getBeginOperandIndex() - + use.getOperandNumber()) + .setType(outputResult.getType()); + } + } + } +} + +// Helper method to apply input operand workarounds. It accepts inputOperand, +// workaround, rewriter and current operation as input arguments. It applies the +// workarounds on the input operand and returns true if the workarounds were +// applied. +static bool workaroundInputOperand(OpOperand &inputOperand, + wa::TTNNOperandWorkarounds &workaround, + PatternRewriter &rewriter, + wa::TTNNWorkaroundInterface op) { + bool modified = false; + // Get the input operand type to extract the tensor layout, buffer type and + // memory layout + auto inputOperandType = + mlir::cast(inputOperand.get().getType()); + ::mlir::tt::ttnn::TTNNLayoutAttr inputLayoutAttr = + mlir::cast<::mlir::tt::ttnn::TTNNLayoutAttr>( + inputOperandType.getEncoding()); + Layout tensorLayout = + llvm::isa(inputLayoutAttr.getMemref().getElementType()) + ? 
Layout::Tile + : Layout::RowMajor; + BufferType tensorBufferType = inputLayoutAttr.getBufferType(); + TensorMemoryLayout tensorMemoryLayout = inputLayoutAttr.getMemLayout(); + + // Apply the workarounds on the input operand workaround arguments + modified |= applyTensorLayoutWorkaround(workaround, tensorLayout); + modified |= applyTensorBufferTypeWorkaround(workaround, tensorBufferType); + modified |= applyTensorMemoryLayoutWorkaround(workaround, tensorMemoryLayout); + + // If the modified flag is set, apply the workarounds on the input operand + // by inserting the ToLayoutOp with the desired tensor layout, buffer type + // and memory layout + if (modified) { + // Create the tensor layout attribute + LayoutAttr tensorLayoutAttr = + LayoutAttr::get(rewriter.getContext(), tensorLayout); + + // Create the data type attribute + DataType dtype = + ttnn::utils::getDataTypeFromMemRef(inputLayoutAttr.getMemref()); + DataTypeAttr dataTypeAttr = DataTypeAttr::get(rewriter.getContext(), dtype); + + // Create the output memory config attribute + ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get( + rewriter.getContext(), + ttnn::TensorMemoryLayoutAttr::get(rewriter.getContext(), + tensorMemoryLayout), + ttnn::BufferTypeAttr::get(rewriter.getContext(), tensorBufferType), + ttnn::ShardSpecAttr::get( + op.getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), + inputLayoutAttr.getMemref().getShape()))); + + // Create element type based on tensor layout + Type elementType = + tensorLayout == Layout::Tile + ? TileType::get( + rewriter.getContext(), {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, + utils::getDataTypeFromMemRef(inputLayoutAttr.getMemref())) + : ttnn::utils::createRowMajorTypeFromDtype( + rewriter.getContext(), + utils::getDataTypeFromMemRef(inputLayoutAttr.getMemref())); + + // Insert a ToLayoutOp to convert the input operand to the desired layout + mlir::Value insertedToLayoutOpValue = + rewriter + .create<ttnn::ToLayoutOp>( + op.getLoc(), + RankedTensorType::get( + inputOperandType.getShape(), + inputOperandType.getElementType(), + inputLayoutAttr + .withElementType(rewriter.getContext(), elementType) + .withBufferType(rewriter.getContext(), tensorBufferType) + .withMemoryLayout(rewriter.getContext(), + tensorMemoryLayout)), + inputOperand.get(), tensorLayoutAttr, dataTypeAttr, + outputMemConfigAttr, + (tensorBufferType == ttnn::BufferType::SystemMemory) + ? nullptr + : utils::getOrInsertDevice(rewriter, op)) + ->getResult(0); + + // Update the input operand with the newly inserted ToLayoutOp result + rewriter.modifyOpInPlace(op, [&]() { + op->setOperand(inputOperand.getOperandNumber(), insertedToLayoutOpValue); + + // If operand is a DPS init, update the result type on current op and + // propagate + DestinationStyleOpInterface dpsOp = + mlir::dyn_cast<DestinationStyleOpInterface>(op.getOperation()); + if (dpsOp && dpsOp.isDpsInit(&inputOperand)) { + // Get DPS inits and calculate the DPS result index + OperandRange dpsInits = dpsOp.getDpsInits(); + int dpsResultIndex = + dpsInits.getBeginOperandIndex() - inputOperand.getOperandNumber(); + + // Get the result of the DPS init and update its type + OpResult opResult = op->getResult(dpsResultIndex); + opResult.setType(insertedToLayoutOpValue.getType()); + + // Propagate output change into next DPS inits operands that use this + // result + propagateOutputChangesIntoDPSInits(opResult); + } + }); + } + + return modified; +} + +// Helper method to apply output operand workarounds. It accepts outputResult, +// workaround, rewriter and current operation as input arguments.
If the result +// is a DPS result, it only verifies that the output operand is the same as the +// corresponding DPS init. At this stage, it is expected that the DPS results are +// already propagated. If the result is not a DPS result, it applies the +// workarounds on the output operand and returns true if the workarounds were +// applied. It also propagates output changes into subsequent DPS inits. +static bool workaroundOutputOperand( + OpResult &outputResult, wa::TTNNOperandWorkarounds &outputWorkaround, + PatternRewriter &rewriter, wa::TTNNWorkaroundInterface op) { + bool modified = false; + + // Get the output result type to extract the tensor layout, buffer type and + // memory layout + RankedTensorType outputType = + mlir::cast<RankedTensorType>(outputResult.getType()); + TTNNLayoutAttr layoutAttr = + mlir::cast<TTNNLayoutAttr>(outputType.getEncoding()); + Layout tensorLayout = + llvm::isa<TileType>(layoutAttr.getMemref().getElementType()) + ? Layout::Tile + : Layout::RowMajor; + BufferType tensorBufferType = layoutAttr.getBufferType(); + TensorMemoryLayout tensorMemoryLayout = layoutAttr.getMemLayout(); + + // Apply the workarounds on the output result workaround arguments + bool tensorLayoutChanged = + applyTensorLayoutWorkaround(outputWorkaround, tensorLayout); + bool tensorBufferTypeChanged = + applyTensorBufferTypeWorkaround(outputWorkaround, tensorBufferType); + bool tensorMemoryLayoutChanged = + applyTensorMemoryLayoutWorkaround(outputWorkaround, tensorMemoryLayout); + + modified = tensorLayoutChanged || tensorBufferTypeChanged || + tensorMemoryLayoutChanged; + // At this point, the DPS result should already be propagated, hence we only + // need to verify that the output workaround did not modify the output result + assert(!(modified && + mlir::isa<DestinationStyleOpInterface>(op.getOperation())) && + "Output operand workarounds not supported for DPS ops"); + + // If the modified flag is set, apply the workarounds on the output result + if (modified && !mlir::isa<DestinationStyleOpInterface>(op.getOperation())) { + // Create the tensor layout attribute + TTNNLayoutAttr outputLayout = + mlir::cast<TTNNLayoutAttr>(outputType.getEncoding()); + + // Create the element type based on tensor layout + Type elementType = + tensorLayout == Layout::Tile + ? TileType::get( + rewriter.getContext(), {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, + utils::getDataTypeFromMemRef(outputLayout.getMemref())) + : ttnn::utils::createRowMajorTypeFromDtype( + rewriter.getContext(), + utils::getDataTypeFromMemRef(outputLayout.getMemref())); + + // Create the new output result type with the updated tensor layout, buffer + // type and memory layout + RankedTensorType newOutputResultType = RankedTensorType::get( + outputType.getShape(), outputType.getElementType(), + outputLayout.withElementType(rewriter.getContext(), elementType) + .withBufferType(rewriter.getContext(), tensorBufferType) + .withMemoryLayout(rewriter.getContext(), tensorMemoryLayout)); + + // Update the type of result with applied workarounds + rewriter.modifyOpInPlace(op, [&]() { + outputResult.setType(newOutputResultType); + + // Some ops define attributes with tensor layout, buffer type and memory + // layout, hence we need to update the attributes as well.
For example, + // the empty op defines layout and memory_config attributes + if (tensorLayoutChanged && op->getAttrDictionary().get("layout")) { + LayoutAttr updatedLayoutAttr = + rewriter.getAttr<LayoutAttr>(tensorLayout); + op->setAttr("layout", updatedLayoutAttr); + } + + if ((tensorBufferTypeChanged || tensorMemoryLayoutChanged) && + op->getAttrDictionary().get("memory_config")) { + // Create the output memory config attribute + ttnn::MemoryConfigAttr updatedMemConfigAttr = + ttnn::MemoryConfigAttr::get( + rewriter.getContext(), + ttnn::TensorMemoryLayoutAttr::get(rewriter.getContext(), + tensorMemoryLayout), + ttnn::BufferTypeAttr::get(rewriter.getContext(), + tensorBufferType), + ttnn::ShardSpecAttr::get( + op.getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), + outputLayout.getMemref().getShape()))); + op->setAttr("memory_config", updatedMemConfigAttr); + } + + // Propagate output change into next DPS inits operands that use this + // result + propagateOutputChangesIntoDPSInits(outputResult); + }); + } + + return modified; +} + +// Rewriter to apply workarounds to the operands of TTNN operations. +// This rewriter applies the workarounds to the input and output operands of +// TTNN operations. +class TTNNOperandsWorkaroundsRewriter + : public OpInterfaceRewritePattern<wa::TTNNWorkaroundInterface> { +public: + TTNNOperandsWorkaroundsRewriter(MLIRContext *ctx) + : OpInterfaceRewritePattern<wa::TTNNWorkaroundInterface>(ctx) {} + + LogicalResult matchAndRewrite(wa::TTNNWorkaroundInterface op, + PatternRewriter &rewriter) const final { + + // ToLayoutOp is a special case; we don't want to rewrite it + if (mlir::isa<ttnn::ToLayoutOp>(op.getOperation())) { + return failure(); + } + + // Get the operands workarounds for the current operation + wa::TTNNOperandsWorkarounds operandsWorkarounds = + op.getOperandsWorkarounds(); + + // Apply input workarounds only for tensor operands + bool modifiedOperands = false; + int input_operand_index = 0; + for (size_t i = 0; + i < operandsWorkarounds.getInputOperandWorkarounds().size(); i++) { + wa::TTNNOperandWorkarounds inputWorkaround = + operandsWorkarounds.getInputOperandWorkarounds()[i]; + + // No input operand workarounds to apply, hence continue + if (!inputWorkaround.hasAnyWorkaround()) { + input_operand_index++; + continue; + } + + // Get the next tensor operand + while (!mlir::isa<RankedTensorType>( + op->getOperand(input_operand_index).getType())) { + input_operand_index++; + } + + OpOperand &inputOperand = op->getOpOperand(input_operand_index++); + // Apply all workaround changes to the input operand + modifiedOperands |= + workaroundInputOperand(inputOperand, inputWorkaround, rewriter, op); + } + + // Apply output workarounds only for tensor operands + int output_operand_index = 0; + for (size_t i = 0; + i < operandsWorkarounds.getOutputOperandWorkarounds().size(); i++) { + wa::TTNNOperandWorkarounds outputWorkaround = + operandsWorkarounds.getOutputOperandWorkarounds()[i]; + + // No output operand workarounds to apply, hence continue + if (!outputWorkaround.hasAnyWorkaround()) { + output_operand_index++; + continue; + } + + // Get the next tensor result + while (!mlir::isa<RankedTensorType>( + op->getResult(output_operand_index).getType())) { + output_operand_index++; + } + + OpResult outputResult = op->getResult(output_operand_index++); + // Apply all workaround changes to the output operand + modifiedOperands |= + workaroundOutputOperand(outputResult, outputWorkaround, rewriter, op); + } + + return modifiedOperands ? success() : failure(); + } +}; + +// Pass to apply workarounds to the operands of TTNN operations.
+class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { +public: + using impl::TTNNWorkaroundsBase::TTNNWorkaroundsBase; + + void runOnOperation() final { + { + // Placeholder for workaround decomposition patterns + } + { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + FrozenRewritePatternSet patternSet(std::move(patterns)); + GreedyRewriteConfig config = GreedyRewriteConfig(); + config.useTopDownTraversal = true; + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet, + config))) { + signalPassFailure(); + return; + } + } + } +}; +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/CMakeLists.txt b/lib/Dialect/TTNN/Utils/CMakeLists.txt index f49f829e6f..52aa13025b 100644 --- a/lib/Dialect/TTNN/Utils/CMakeLists.txt +++ b/lib/Dialect/TTNN/Utils/CMakeLists.txt @@ -1,6 +1,8 @@ add_mlir_dialect_library(TTMLIRTTNNUtils - Utils.cpp OptimizerOverrides.cpp + TransformUtils.cpp + Utils.cpp + ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/TTNN diff --git a/lib/Dialect/TTNN/Utils/TransformUtils.cpp b/lib/Dialect/TTNN/Utils/TransformUtils.cpp new file mode 100644 index 0000000000..44b01e91b3 --- /dev/null +++ b/lib/Dialect/TTNN/Utils/TransformUtils.cpp @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" + +namespace mlir::tt::ttnn::utils { +// Gets or inserts a GetDeviceOp at the top of the current block of the given +// operation. +Value getOrInsertDevice(PatternRewriter &rewriter, Operation *op) { + Block *block = op->getBlock(); + for (auto &op : block->getOperations()) { + if (auto deviceOp = dyn_cast(op)) { + return deviceOp.getResult(); + } + } + + DeviceAttr deviceAttr = getCurrentScopeDevice(op); + auto currentInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(block, block->begin()); + auto deviceOp = rewriter.create( + op->getLoc(), rewriter.getType(deviceAttr), + ttnn::MeshShapeAttr::get(op->getContext(), 1, 1)); + rewriter.restoreInsertionPoint(currentInsertionPoint); + return deviceOp.getResult(); +} +} // namespace mlir::tt::ttnn::utils diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir new file mode 100644 index 0000000000..9eed399840 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttnn-workaround %s | FileCheck %s +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, interleaved> +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> 
+module attributes {tt.device = #device} { + func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]] + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, #dram, <<2x4>>>}> : (tensor<64x128xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + // CHECK-NEXT: %[[RM_DEVICE_LAYOUT_OP:.*]] = "ttnn.to_layout"(%arg0, %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout1> + %2 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, #dram, <<64x128>>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<, #dram, <<64x128>>> + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout1> + %3 = "ttnn.abs"(%1, %2) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[TO_LAYOUT_LEFT:.*]] = "ttnn.to_layout"(%[[RM_DEVICE_LAYOUT_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[TO_LAYOUT_RIGHT:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout2> + %4 = "ttnn.to_layout"(%3) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, #system_memory, <<64x128>>>}> : (tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout> + return %4 : tensor<64x128xf32, #ttnn_layout> + } +}
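
Usage note: a minimal sketch of how the pass introduced in this patch can be exercised, assuming a built ttmlir-opt binary. The standalone pass is registered as ttnn-workaround (this is what the lit test above runs), and, assuming the backend pipeline is registered under its usual ttir-to-ttnn-backend-pipeline name, the enable-workaround-pass option added in TTNNPipelines.h turns it on inside the full pipeline; the input file names below are placeholders for whatever TTNN-dialect module is being compiled:

  ttmlir-opt --ttnn-workaround simple_workaround.mlir
  ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-workaround-pass=true" input.mlir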