Adding workaround op interface

sdjordjevicTT committed Nov 29, 2024
1 parent 87cdd07 commit 74aee6f
Showing 26 changed files with 940 additions and 35 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

add_mlir_interface(TTNNOpModelInterface)
add_mlir_interface(TTNNWorkaroundInterface)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
@@ -7,6 +7,7 @@

include "mlir/IR/OpBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td"
include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td"

//===----------------------------------------------------------------------===//
// TTNN dialect definition.
@@ -44,6 +45,6 @@ def TTNN_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTNN_Op<string mnemonic, list<Trait> traits = []> :
-    Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNN_OpModelInterface])>;
+    Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNN_OpModelInterface, TTNN_WorkaroundInterface])>;

#endif
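
After this change every TTNN_Op carries TTNN_WorkaroundInterface alongside TTNN_OpModelInterface, so passes can query workarounds uniformly across the dialect. A minimal sketch of that query, using the standard MLIR interface-cast idiom (the op pointer is hypothetical; the interface class is generated from the .td file added later in this commit):

    if (auto waOp = mlir::dyn_cast<mlir::tt::ttnn::wa::TTNNWorkaroundInterface>(op)) {
      // Ops that declare nothing fall back to the interface's default
      // implementation, which returns an empty workaround set.
      auto workarounds = waOp.getOperandsWorkarounds();
      (void)workarounds;
    }
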
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
@@ -18,6 +18,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h"

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h.inc"
29 changes: 27 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -174,6 +174,22 @@ def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
let description = [{
Eltwise absolute operation.
}];

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(
wa::TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()
.setWorkaround(Layout::Tile))
.addInputOperandWorkaround(
wa::TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()
.setWorkaround(Layout::Tile))
.addOutputOperandWorkaround(
wa::TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()
.setWorkaround(Layout::Tile));
}
}];
}

def TTNN_CbrtOp : TTNN_ElementwiseUnaryOp<"cbrt"> {
@@ -325,7 +341,7 @@ def TTNN_Expm1Op: TTNN_ElementwiseUnaryOp<"expm1"> {
}];
}

-class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
+class TTNN_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
TTNN_ElementwiseUnaryOp<mnemonic, traits> {
let summary = "Eltwise unary op with the float parameter.";
let description = [{
@@ -345,7 +361,7 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> tra
];
}

-def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
+def TTNN_LeakyReluOp : TTNN_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
let summary = "Eltwise leaky relu operation.";
let description = [{
The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise
@@ -784,6 +800,15 @@ def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addOutputOperandWorkaround(
wa::TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()
.setWorkaround(Layout::RowMajor));
}
}];

let hasVerifier = 1;
}

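Taken together with TTNNWorkarounds.h (added later in this commit), the abs declaration above pins both inputs, the activation and the DPS init, plus the result to tile layout, while ttnn.empty pins its result to row-major. A small sketch of reading such a declaration back through the new accessors, written as if inside namespace mlir::tt::ttnn with a hypothetical absOp handle:

    wa::TTNNOperandsWorkarounds workarounds = absOp.getOperandsWorkarounds();
    for (const wa::TTNNOperandWorkarounds &operandWA :
         workarounds.getInputOperandWorkarounds()) {
      // For abs, each input carries a tile-layout workaround; the buffer
      // type and memory layout workarounds stay unset (std::nullopt).
      std::optional<Layout> layout = operandWA.getWorkaround<Layout>();
      assert(layout && *layout == Layout::Tile);
    }
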
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -92,6 +92,9 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
{
return this->getShardSpec().getShardShape().getShape();
}

MemoryConfigAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType);
MemoryConfigAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
}];
}

17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDINTERFACE_H
#define TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDINTERFACE_H

#include "mlir/IR/Operation.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h"

namespace mlir::tt::ttnn::wa {
// Verifies the TTNNWorkaroundInterface
mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op);
} // namespace mlir::tt::ttnn::wa

#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h.inc"

#endif
47 changes: 47 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD
#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD

include "mlir/IR/OpBase.td"

// This interface is used to specify workarounds for TTNN operations.
def TTNN_WorkaroundInterface : OpInterface<"TTNNWorkaroundInterface"> {
let cppNamespace = "::mlir::tt::ttnn::wa";
let methods = [
InterfaceMethod<
/*desc=*/[{
Returns the workarounds associated with each operand and result of this operation.
If the operation is a Destination-Passing Style (DPS) operation, the same workarounds
must apply to both the DPS initial operands and the operation results. These constraints
are verified through the interface verifier.

For example, consider the following ttnn operations:
%0 = "ttnn.empty"() : () -> tensor<1x1xf32>
%1 = "ttnn.abs"(%arg0, %0) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>

In this example:
- The ttnn.abs operation has two input operand workarounds.
- It has one output operand workaround.
- The output workaround must match the workaround for the second input operand,
ensuring consistency as required by the DPS pattern.
}],
/*retTy=*/"TTNNOperandsWorkarounds",
/*methodName=*/"getOperandsWorkarounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Return default empty workarounds for all input and output operands
return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(this->getOperation());
}]
>,
];

let verify = [{
return verifyTTNNWorkaroundInterface($_op);
}];
}

#endif
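
The verifier body (verifyTTNNWorkaroundInterface) lives in a .cpp file outside this excerpt, so the following is only a guess at the DPS-consistency rule the description spells out, written against the TTNNWorkarounds.h API with hypothetical op and waOp handles:

    // Sketch of the assumed rule: result i is tied to the i-th DPS init,
    // whose workarounds sit at the tail of the input operand list.
    auto workarounds = waOp.getOperandsWorkarounds();
    auto inputs = workarounds.getInputOperandWorkarounds();
    auto outputs = workarounds.getOutputOperandWorkarounds();
    size_t firstInit = inputs.size() - outputs.size();
    for (size_t i = 0; i < outputs.size(); ++i)
      if (outputs[i] != inputs[firstInit + i])
        return op->emitOpError("DPS init workaround must match result workaround");
    return mlir::success();
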
154 changes: 154 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -0,0 +1,154 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDS_H
#define TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDS_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <memory>
#include <mlir/IR/Operation.h>
#include <optional>
#include <vector>

namespace mlir::tt::ttnn::wa {
using TensorLayoutWorkaround = std::optional<Layout>;
using TensorBufferTypeWorkaround = std::optional<BufferType>;
using TensorMemoryLayoutWorkaround = std::optional<TensorMemoryLayout>;

// Class that encapsulates operand workarounds.
// It contains tensor layout, tensor buffer type and tensor memory layout
// workarounds.
class TTNNOperandWorkarounds {
public:
// Default constructor with no workarounds.
TTNNOperandWorkarounds() {}

// Constructor that takes tensor layout, tensor buffer type, and tensor
// memory layout workarounds.
TTNNOperandWorkarounds(
TensorLayoutWorkaround tensorLayoutWorkaround,
TensorBufferTypeWorkaround tensorBufferTypeWorkaround,
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
: tensorLayoutWorkaround(tensorLayoutWorkaround),
tensorBufferTypeWorkaround(tensorBufferTypeWorkaround),
tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround) {}

// Operand workarounds factory methods
static TTNNOperandWorkarounds createEmptyTTNNOperandWorkarounds();

// Returns a specific workaround.
template <typename T>
std::optional<T> getWorkaround() const {
if constexpr (std::is_same<T, Layout>::value) {
return tensorLayoutWorkaround;
} else if constexpr (std::is_same<T, BufferType>::value) {
return tensorBufferTypeWorkaround;
} else if constexpr (std::is_same<T, TensorMemoryLayout>::value) {
return tensorMemoryLayoutWorkaround;
}
return std::nullopt;
}

// Sets a specific workaround.
template <typename T>
TTNNOperandWorkarounds &setWorkaround(T workaround) {
if constexpr (std::is_same<T, Layout>::value) {
tensorLayoutWorkaround = std::make_optional(workaround);
} else if constexpr (std::is_same<T, BufferType>::value) {
tensorBufferTypeWorkaround = std::make_optional(workaround);
} else if constexpr (std::is_same<T, TensorMemoryLayout>::value) {
tensorMemoryLayoutWorkaround = std::make_optional(workaround);
}
return *this;
}

// Equality operator.
bool operator==(const TTNNOperandWorkarounds &rhs) const {
return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround &&
tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround &&
tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround;
}

// Inequality operator.
bool operator!=(const TTNNOperandWorkarounds &rhs) const {
return !(*this == rhs);
}

// Returns true if any of the workarounds is set.
bool hasAnyWorkaround() const {
return tensorLayoutWorkaround || tensorBufferTypeWorkaround ||
tensorMemoryLayoutWorkaround;
}

private:
// Tensor layout workaround
TensorLayoutWorkaround tensorLayoutWorkaround;
// Tensor buffer type workaround
TensorBufferTypeWorkaround tensorBufferTypeWorkaround;
// Tensor memory layout workaround.
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround;
};

// Class that encapsulates operands workarounds.
// It contains input and output workarounds for operands.
class TTNNOperandsWorkarounds {
public:
// Default constructor with no workarounds.
TTNNOperandsWorkarounds() {}

// Constructor that takes input and output workarounds for operands.
TTNNOperandsWorkarounds(
llvm::SmallVector<TTNNOperandWorkarounds> inputOperandWorkarounds,
llvm::SmallVector<TTNNOperandWorkarounds> outputOperandWorkarounds)
: inputOperandWorkarounds(std::move(inputOperandWorkarounds)),
outputOperandWorkarounds(std::move(outputOperandWorkarounds)) {}

// Returns input operand workarounds.
llvm::ArrayRef<TTNNOperandWorkarounds> getInputOperandWorkarounds() const {
return inputOperandWorkarounds;
}

// Returns output operand workarounds.
llvm::ArrayRef<TTNNOperandWorkarounds> getOutputOperandWorkarounds() const {
return outputOperandWorkarounds;
}

// Adds input operand workaround.
TTNNOperandsWorkarounds &
addInputOperandWorkaround(TTNNOperandWorkarounds inputOperandWorkaround) {
inputOperandWorkarounds.emplace_back(inputOperandWorkaround);
return *this;
}

// Adds output operand workaround.
TTNNOperandsWorkarounds &
addOutputOperandWorkaround(TTNNOperandWorkarounds outputOperandWorkaround) {
outputOperandWorkarounds.emplace_back(outputOperandWorkaround);
return *this;
}

// Operands workarounds factory method
static TTNNOperandsWorkarounds
createEmptyTTNNOperandsWorkarounds(int inputSize, int outputSize);

// Operands workarounds factory method
static TTNNOperandsWorkarounds createEmptyTTNNOperandsWorkarounds() {
return createEmptyTTNNOperandsWorkarounds(0, 0);
}

// Operands workarounds factory method
static TTNNOperandsWorkarounds
createEmptyTTNNOperandsWorkarounds(Operation *op);

private:
// Workarounds for input operands.
llvm::SmallVector<TTNNOperandWorkarounds> inputOperandWorkarounds;
// Workarounds for output operands.
llvm::SmallVector<TTNNOperandWorkarounds> outputOperandWorkarounds;
};

} // namespace mlir::tt::ttnn::wa

#endif
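
The two classes compose into a small fluent API, and the op declarations earlier in this commit use exactly this pattern. A self-contained usage sketch built only from the methods above:

    using namespace mlir::tt::ttnn::wa;

    // One input pinned to tile layout, one output pinned to row-major.
    TTNNOperandsWorkarounds workarounds =
        TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
            .addInputOperandWorkaround(
                TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()
                    .setWorkaround(Layout::Tile))
            .addOutputOperandWorkaround(
                TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()
                    .setWorkaround(Layout::RowMajor));

    // Each getter returns std::nullopt when the workaround is not set.
    assert(workarounds.getInputOperandWorkarounds()[0].hasAnyWorkaround());
    assert(workarounds.getOutputOperandWorkarounds()[0]
               .getWorkaround<Layout>() == Layout::RowMajor);
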
5 changes: 5 additions & 0 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -111,6 +111,11 @@ struct TTIRToTTNNBackendPipelineOptions

ListOption<int64_t> meshShape{
*this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};

// Option to enable/disable the workaround pass
Option<bool> workaroundPassEnabled{*this, "enable-workaround-pass",
llvm::cl::desc("Enable workaround pass."),
llvm::cl::init(false)};
};

// TTIR to EmitC pipeline options.
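
The new flag defaults to off. Assuming the options struct registers under the usual ttir-to-ttnn-backend-pipeline name and the project ships a ttmlir-opt driver (neither is shown in this diff), enabling it would look like:

    ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-workaround-pass=true" input.mlir
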
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -28,4 +28,12 @@ def TTNNLayout : Pass<"ttnn-layout", "::mlir::ModuleOp"> {
}];
}

def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
let summary = "Apply workarounds to the IR.";
let description = [{
This pass applies necessary TTNN workarounds to the IR in order to create
a valid and functional IR that can be executed on the hardware.
}];
}

#endif
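
Because the pass registers under the ttnn-workaround flag, it should also be runnable standalone for testing (same assumed ttmlir-opt driver as above):

    ttmlir-opt --ttnn-workaround ttnn_module.mlir
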
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H
#define TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"

namespace mlir::tt::ttnn::utils {
// Get or insert device for the given operation.
mlir::Value getOrInsertDevice(mlir::PatternRewriter &rewriter,
mlir::Operation *op);
} // namespace mlir::tt::ttnn::utils

#endif
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -5,6 +5,7 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
#define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H

#include "mlir/IR/BuiltinTypes.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
@@ -34,6 +35,11 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);
mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

// Helper method to create a RankedTensorType with the given encoding
RankedTensorType
createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
ttnn::TTNNLayoutAttr encoding);

} // namespace mlir::tt::ttnn::utils

#endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
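
The new helper is presumably a thin wrapper over the builtin tensor type constructor that swaps only the encoding attribute; a plausible sketch, not the committed implementation (which lives in a .cpp outside this excerpt):

    RankedTensorType
    createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
                                       ttnn::TTNNLayoutAttr encoding) {
      // Keep shape and element type; attach the TTNN layout as the encoding.
      return RankedTensorType::get(tensorType.getShape(),
                                   tensorType.getElementType(), encoding);
    }
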
5 changes: 5 additions & 0 deletions include/ttmlir/Utils.h
@@ -127,6 +127,11 @@ inline MlirAttribute wrapArrayOfMlirAttributesAsAttribute(
return wrap(mlir::ArrayAttr::get(unwrap(ctx), unwrappedAttributesArray));
}

// Checks if the given `mlir::Type` is a ranked tensor type.
inline bool isRankedTensorType(mlir::Type type) {
return mlir::isa<mlir::RankedTensorType>(type);
}

} // namespace ttmlir::utils

#endif