Skip to content

Commit

Permalink
Adding workaround op interface
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT committed Nov 26, 2024
1 parent 9d731b7 commit 8220bee
Show file tree
Hide file tree
Showing 19 changed files with 1,104 additions and 34 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

add_mlir_interface(TTNNOpModelInterface)
add_mlir_interface(TTNNWorkaroundInterface)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

include "mlir/IR/OpBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td"
include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td"

//===----------------------------------------------------------------------===//
// TTNN dialect definition.
Expand Down Expand Up @@ -44,6 +45,6 @@ def TTNN_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTNN_Op<string mnemonic, list<Trait> traits = []> :
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNN_OpModelInterface])>;
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNN_OpModelInterface, TTNN_WorkaroundInterface])>;

#endif
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h"

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h.inc"
Expand Down
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def TTNN_Expm1Op: TTNN_ElementwiseUnaryOp<"expm1"> {
}];
}

class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
class TTNN_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
TTNN_ElementwiseUnaryOp<mnemonic, traits> {
let summary = "Eltwise unary op with the float parameter.";
let description = [{
Expand All @@ -345,7 +345,7 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> tra
];
}

def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
def TTNN_LeakyReluOp : TTNN_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
let summary = "Eltwise leaky relu operation.";
let description = [{
The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise
Expand Down
22 changes: 22 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_H
#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_H

#include "mlir/IR/Operation.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h"

namespace mlir::tt::ttnn::wa {
// Gets default operand workarounds for the given operation. This method is
// called from the TTNNWorkaroundInterface and its getOperandsWorkarounds
// method.
TTNNOperandsWorkarounds getDefaultOperandWorkarounds(Operation *op);

// Verifies the TTNNWorkaroundInterface
mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op);
} // namespace mlir::tt::ttnn::wa

#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h.inc"

#endif
46 changes: 46 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD
#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD

include "mlir/IR/OpBase.td"

// This interface is used to specify workarounds for TTNN operations.
def TTNN_WorkaroundInterface : OpInterface<"TTNNWorkaroundInterface"> {
let cppNamespace = "::mlir::tt::ttnn::wa";
let methods = [
InterfaceMethod<
/*desc=*/[{
Returns the workarounds associated with each operand and result of this operation.
If the operation is a Destination-Passing Style (DPS) operation, the same workarounds
must apply to both the DPS initial operands and the operation results. These constraints
are verified through the interface verifier.

For example, consider the following ttnn operations:
%0 = "ttnn.emptyOp"() : () -> tensor<1x1xf32>
%1 = "ttnn.abs"(%arg0, %0) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>

In this example:
- The ttnn.abs operation has two input operand workarounds.
- It has one output operand workaround.
- The output workaround must match the workaround for the second input operand,
ensuring consistency as required by the DPS pattern.
}],
/*retTy=*/"TTNNOperandsWorkarounds",
/*methodName=*/"getOperandsWorkarounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getDefaultOperandWorkarounds(this->getOperation());
}]
>,
];

let verify = [{
return verifyTTNNWorkaroundInterface($_op);
}];
}

#endif
Loading

0 comments on commit 8220bee

Please sign in to comment.