-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
87cdd07
commit 74aee6f
Showing
26 changed files
with
940 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDINTERFACE_H | ||
#define TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDINTERFACE_H | ||
|
||
#include "mlir/IR/Operation.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" | ||
|
||
namespace mlir::tt::ttnn::wa { | ||
// Verifies the TTNNWorkaroundInterface | ||
mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op); | ||
} // namespace mlir::tt::ttnn::wa | ||
|
||
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h.inc" | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD | ||
#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD | ||
|
||
include "mlir/IR/OpBase.td" | ||
|
||
// This interface is used to specify workarounds for TTNN operations. | ||
def TTNN_WorkaroundInterface : OpInterface<"TTNNWorkaroundInterface"> { | ||
let cppNamespace = "::mlir::tt::ttnn::wa"; | ||
let methods = [ | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Returns the workarounds associated with each operand and result of this operation. | ||
If the operation is a Destination-Passing Style (DPS) operation, the same workarounds | ||
must apply to both the DPS initial operands and the operation results. These constraints | ||
are verified through the interface verifier. | ||
|
||
For example, consider the following ttnn operations: | ||
%0 = "ttnn.empty"() : () -> tensor<1x1xf32> | ||
%1 = "ttnn.abs"(%arg0, %0) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> | ||
|
||
In this example: | ||
- The ttnn.abs operation has two input operand workarounds. | ||
- It has one output operand workaround. | ||
- The output workaround must match the workaround for the second input operand, | ||
ensuring consistency as required by the DPS pattern. | ||
}], | ||
/*retTy=*/"TTNNOperandsWorkarounds", | ||
/*methodName=*/"getOperandsWorkarounds", | ||
/*args=*/(ins), | ||
/*methodBody=*/"", | ||
/*defaultImplementation=*/[{ | ||
// Return default empty workarounds for all input and output operands | ||
return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(this->getOperation()); | ||
}] | ||
>, | ||
]; | ||
|
||
let verify = [{ | ||
return verifyTTNNWorkaroundInterface($_op); | ||
}]; | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDS_H | ||
#define TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDS_H | ||
|
||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" | ||
#include <llvm/ADT/ArrayRef.h> | ||
#include <llvm/ADT/SmallVector.h> | ||
#include <memory> | ||
#include <mlir/IR/Operation.h> | ||
#include <optional> | ||
#include <vector> | ||
|
||
namespace mlir::tt::ttnn::wa { | ||
using TensorLayoutWorkaround = std::optional<Layout>; | ||
using TensorBufferTypeWorkaround = std::optional<BufferType>; | ||
using TensorMemoryLayoutWorkaround = std::optional<TensorMemoryLayout>; | ||
|
||
// Class that encapsulates operand workarounds. | ||
// It contains tensor layout, tensor buffer type and tensor memory layout | ||
// workarounds. | ||
class TTNNOperandWorkarounds { | ||
public: | ||
// Default constructor with no workarounds. | ||
TTNNOperandWorkarounds() {} | ||
|
||
// Constructor that takes tensor layout, tensor buffer type and tensor memory | ||
TTNNOperandWorkarounds( | ||
TensorLayoutWorkaround tensorLayoutWorkaround, | ||
TensorBufferTypeWorkaround tensorBufferTypeWorkaround, | ||
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround) | ||
: tensorLayoutWorkaround(tensorLayoutWorkaround), | ||
tensorBufferTypeWorkaround(tensorBufferTypeWorkaround), | ||
tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround) {} | ||
|
||
// Operand workarounds factory methods | ||
static TTNNOperandWorkarounds createEmptyTTNNOperandWorkarounds(); | ||
|
||
// Returns a specific workaround. | ||
template <typename T> | ||
std::optional<T> getWorkaround() const { | ||
if constexpr (std::is_same<T, Layout>::value) { | ||
return tensorLayoutWorkaround; | ||
} else if constexpr (std::is_same<T, BufferType>::value) { | ||
return tensorBufferTypeWorkaround; | ||
} else if constexpr (std::is_same<T, TensorMemoryLayout>::value) { | ||
return tensorMemoryLayoutWorkaround; | ||
} | ||
return std::nullopt; | ||
} | ||
|
||
// Sets a specific workaround. | ||
template <typename T> | ||
TTNNOperandWorkarounds &setWorkaround(T workaround) { | ||
if constexpr (std::is_same<T, Layout>::value) { | ||
tensorLayoutWorkaround = std::make_optional(workaround); | ||
} else if constexpr (std::is_same<T, BufferType>::value) { | ||
tensorBufferTypeWorkaround = std::make_optional(workaround); | ||
} else if constexpr (std::is_same<T, TensorMemoryLayout>::value) { | ||
tensorMemoryLayoutWorkaround = std::make_optional(workaround); | ||
} | ||
return *this; | ||
} | ||
|
||
// Equality operator. | ||
bool operator==(const TTNNOperandWorkarounds &rhs) const { | ||
return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround && | ||
tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround && | ||
tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround; | ||
} | ||
|
||
// Inequality operator. | ||
bool operator!=(const TTNNOperandWorkarounds &rhs) const { | ||
return !(*this == rhs); | ||
} | ||
|
||
// Returns true if any of the workarounds is set. | ||
bool hasAnyWorkaround() const { | ||
return tensorLayoutWorkaround || tensorBufferTypeWorkaround || | ||
tensorMemoryLayoutWorkaround; | ||
} | ||
|
||
private: | ||
// Tensor layout workaround | ||
TensorLayoutWorkaround tensorLayoutWorkaround; | ||
// Tensor buffer type workaround | ||
TensorBufferTypeWorkaround tensorBufferTypeWorkaround; | ||
// Tensor memory layout workaround. | ||
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround; | ||
}; | ||
|
||
// Class that encapsulates operands workarounds. | ||
// It contains input and output workarounds for operands. | ||
class TTNNOperandsWorkarounds { | ||
public: | ||
// Default constructor with no workarounds. | ||
TTNNOperandsWorkarounds() {} | ||
|
||
// Constructor that takes input and output workarounds for operands. | ||
TTNNOperandsWorkarounds( | ||
llvm::SmallVector<TTNNOperandWorkarounds> inputOperandWorkarounds, | ||
llvm::SmallVector<TTNNOperandWorkarounds> outputOperandWorkarounds) | ||
: inputOperandWorkarounds(std::move(inputOperandWorkarounds)), | ||
outputOperandWorkarounds(std::move(outputOperandWorkarounds)) {} | ||
|
||
// Returns input operand workarounds. | ||
llvm::ArrayRef<TTNNOperandWorkarounds> getInputOperandWorkarounds() const { | ||
return inputOperandWorkarounds; | ||
} | ||
|
||
// Returns output operand workarounds. | ||
llvm::ArrayRef<TTNNOperandWorkarounds> getOutputOperandWorkarounds() const { | ||
return outputOperandWorkarounds; | ||
} | ||
|
||
// Adds input operand workaround. | ||
TTNNOperandsWorkarounds & | ||
addInputOperandWorkaround(TTNNOperandWorkarounds inputOperandWorkaround) { | ||
inputOperandWorkarounds.emplace_back(inputOperandWorkaround); | ||
return *this; | ||
} | ||
|
||
// Adds output operand workaround. | ||
TTNNOperandsWorkarounds & | ||
addOutputOperandWorkaround(TTNNOperandWorkarounds outputOperandWorkaround) { | ||
outputOperandWorkarounds.emplace_back(outputOperandWorkaround); | ||
return *this; | ||
} | ||
|
||
// Operands workarounds factory method | ||
static TTNNOperandsWorkarounds | ||
createEmptyTTNNOperandsWorkarounds(int inputSize, int outputSize); | ||
|
||
// Operands workarounds factory method | ||
static TTNNOperandsWorkarounds createEmptyTTNNOperandsWorkarounds() { | ||
return createEmptyTTNNOperandsWorkarounds(0, 0); | ||
} | ||
|
||
// Operands workarounds factory method | ||
static TTNNOperandsWorkarounds | ||
createEmptyTTNNOperandsWorkarounds(Operation *op); | ||
|
||
private: | ||
// Workarounds for input operands. | ||
llvm::SmallVector<TTNNOperandWorkarounds> inputOperandWorkarounds; | ||
// Workarounds for output operands. | ||
llvm::SmallVector<TTNNOperandWorkarounds> outputOperandWorkarounds; | ||
}; | ||
|
||
} // namespace mlir::tt::ttnn::wa | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H | ||
#define TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H | ||
|
||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/Value.h" | ||
|
||
namespace mlir::tt::ttnn::utils { | ||
// Get or insert device for the given operation. | ||
mlir::Value getOrInsertDevice(mlir::PatternRewriter &rewriter, | ||
mlir::Operation *op); | ||
} // namespace mlir::tt::ttnn::utils | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.