forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move common RNN utilities into src/Conversion/ONNXConversionCommon
Signed-off-by: Yan Xu <[email protected]>
- Loading branch information
Showing
14 changed files
with
361 additions
and
482 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
add_onnx_mlir_library(OMONNXConversionCommon | ||
RNN/RNNBase.cpp | ||
RNN/LSTM.cpp | ||
|
||
LINK_LIBS PUBLIC | ||
|
||
OMONNXOps | ||
OMSupport | ||
MLIRTransforms | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===// | ||
// | ||
// Copyright 2019-2024 The IBM Research Authors. | ||
// Modifications Copyright 2023-2024 | ||
// | ||
// ============================================================================= | ||
// | ||
// This file includes utilities for lowering the ONNX LSTM Operators. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXConversionCommon/RNN/LSTM.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
template <> | ||
bool hasAllNoneOutput<ONNXLSTMOp>(ONNXLSTMOp *op) { | ||
return (isNoneValue(op->getY()) && isNoneValue(op->getYH()) && | ||
isNoneValue(op->getYC())); | ||
} | ||
|
||
template <> | ||
std::tuple<LstmActivationPack, LstmActivationPack> | ||
getActivationPack<ONNXLSTMOp, LstmActivationPack>(ONNXLSTMOp *op) { | ||
auto direction = op->getDirection(); | ||
auto activations = op->getActivations(); | ||
auto activationAlpha = op->getActivationAlpha(); | ||
auto activationBeta = op->getActivationBeta(); | ||
|
||
LstmActivationPack activationForward, activationReverse; | ||
|
||
// Get activation function name. | ||
// Default forward functions | ||
activationForward.f.name = "sigmoid"; | ||
activationForward.g.name = "tanh"; | ||
activationForward.h.name = "tanh"; | ||
// Default backward functions | ||
activationReverse.f.name = "sigmoid"; | ||
activationReverse.g.name = "tanh"; | ||
activationReverse.h.name = "tanh"; | ||
if (activations) { | ||
ArrayAttr activationArrAttr = activations.value(); | ||
if (direction == FORWARD || direction == BIDIRECTIONAL) { | ||
// Forward activations. | ||
if (activationArrAttr.size() > 0) { | ||
activationForward.f.name = | ||
activationArrAttr[0].cast<StringAttr>().getValue(); | ||
} | ||
if (activationArrAttr.size() > 1) { | ||
activationForward.g.name = | ||
activationArrAttr[1].cast<StringAttr>().getValue(); | ||
} | ||
if (activationArrAttr.size() > 2) { | ||
activationForward.h.name = | ||
activationArrAttr[2].cast<StringAttr>().getValue(); | ||
} | ||
} | ||
|
||
// Reverse activations. | ||
if (direction == REVERSE || direction == BIDIRECTIONAL) { | ||
unsigned int startIndex = (direction == REVERSE) ? 0 : 3; | ||
if (activationArrAttr.size() > startIndex) { | ||
activationReverse.f.name = | ||
activationArrAttr[startIndex].cast<StringAttr>().getValue(); | ||
} | ||
if (activationArrAttr.size() > startIndex + 1) { | ||
activationReverse.g.name = | ||
activationArrAttr[startIndex + 1].cast<StringAttr>().getValue(); | ||
} | ||
if (activationArrAttr.size() > startIndex + 2) { | ||
activationReverse.h.name = | ||
activationArrAttr[startIndex + 2].cast<StringAttr>().getValue(); | ||
} | ||
} | ||
} | ||
|
||
// Get alpha attributes. | ||
if (activationAlpha) { | ||
ArrayAttr activationArrAttr = activationAlpha.value(); | ||
if (direction == FORWARD || direction == BIDIRECTIONAL) { | ||
// Forward activations. | ||
if (activationArrAttr.size() > 0) { | ||
activationForward.f.alpha = activationArrAttr[0].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > 1) { | ||
activationForward.g.alpha = activationArrAttr[1].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > 2) { | ||
activationForward.h.alpha = activationArrAttr[2].cast<FloatAttr>(); | ||
} | ||
} | ||
|
||
// Reverse activations. | ||
if (direction == REVERSE || direction == BIDIRECTIONAL) { | ||
unsigned int startIndex = (direction == REVERSE) ? 0 : 3; | ||
if (activationArrAttr.size() > startIndex) { | ||
activationReverse.f.alpha = | ||
activationArrAttr[startIndex].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > startIndex + 1) { | ||
activationReverse.g.alpha = | ||
activationArrAttr[startIndex + 1].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > startIndex + 2) { | ||
activationReverse.h.alpha = | ||
activationArrAttr[startIndex + 2].cast<FloatAttr>(); | ||
} | ||
} | ||
} | ||
|
||
// Get beta attributes. | ||
if (activationBeta) { | ||
ArrayAttr activationArrAttr = activationBeta.value(); | ||
if (direction == FORWARD || direction == BIDIRECTIONAL) { | ||
// Forward activations. | ||
if (activationArrAttr.size() > 0) { | ||
activationForward.f.beta = activationArrAttr[0].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > 1) { | ||
activationForward.g.beta = activationArrAttr[1].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > 2) { | ||
activationForward.h.beta = activationArrAttr[2].cast<FloatAttr>(); | ||
} | ||
} | ||
|
||
// Reverse activations. | ||
if (direction == REVERSE || direction == BIDIRECTIONAL) { | ||
unsigned int startIndex = (direction == REVERSE) ? 0 : 3; | ||
if (activationArrAttr.size() > startIndex) { | ||
activationReverse.f.beta = | ||
activationArrAttr[startIndex].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > startIndex + 1) { | ||
activationReverse.g.beta = | ||
activationArrAttr[startIndex + 1].cast<FloatAttr>(); | ||
} | ||
if (activationArrAttr.size() > startIndex + 2) { | ||
activationReverse.h.beta = | ||
activationArrAttr[startIndex + 2].cast<FloatAttr>(); | ||
} | ||
} | ||
} | ||
|
||
return std::make_tuple(activationForward, activationReverse); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===--------------- LSTM.hpp - Lowering LSTM Op --------------------------===// | ||
// | ||
// Copyright 2024 | ||
// | ||
// ============================================================================= | ||
// | ||
// This file includes utilities for lowering the ONNX LSTM Operators. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp" | ||
|
||
namespace onnx_mlir { | ||
|
||
struct LstmActivationPack { | ||
RNNActivation f; | ||
RNNActivation g; | ||
RNNActivation h; | ||
}; | ||
|
||
struct LstmWeightPack { | ||
mlir::Value WT; | ||
mlir::Value RT; | ||
}; | ||
|
||
struct LstmBiasPack { | ||
bool hasBias = false; | ||
mlir::Value Wbi; | ||
mlir::Value Wbo; | ||
mlir::Value Wbf; | ||
mlir::Value Wbc; | ||
mlir::Value Rbi; | ||
mlir::Value Rbo; | ||
mlir::Value Rbf; | ||
mlir::Value Rbc; | ||
// Put peephole here. | ||
bool hasPeephole = false; | ||
mlir::Value Pi; | ||
mlir::Value Po; | ||
mlir::Value Pf; | ||
}; | ||
|
||
template <> | ||
bool hasAllNoneOutput<mlir::ONNXLSTMOp>(mlir::ONNXLSTMOp *op); | ||
|
||
template <> | ||
std::tuple<LstmActivationPack, LstmActivationPack> | ||
getActivationPack<mlir::ONNXLSTMOp, LstmActivationPack>(mlir::ONNXLSTMOp *op); | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===// | ||
// | ||
// Copyright 2019-2024 The IBM Research Authors. | ||
// Modifications Copyright 2023-2024 | ||
// | ||
// ============================================================================= | ||
// | ||
// This file defines common base utilities for lowering the ONNX RNN Operators. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
// Get a dimension of the tensor's shape. | ||
int64_t dimAt(Value val, int index) { | ||
return val.getType().cast<ShapedType>().getShape()[index]; | ||
} | ||
|
||
// Apply an activation function on a given scalar operand. | ||
Value applyActivation(OpBuilder &rewriter, Location loc, | ||
RNNActivation activation, Value operand) { | ||
Value res; | ||
|
||
std::vector<mlir::NamedAttribute> attributes; | ||
if (activation.alpha) { | ||
attributes.emplace_back( | ||
rewriter.getNamedAttr("alpha", activation.alpha.value())); | ||
} | ||
if (activation.beta) { | ||
attributes.emplace_back( | ||
rewriter.getNamedAttr("beta", activation.beta.value())); | ||
} | ||
Type resType = operand.getType(); | ||
|
||
// Change equality to be case insensitive. | ||
if (activation.name.equals_insensitive("relu")) | ||
res = rewriter.create<ONNXReluOp>(loc, resType, operand); | ||
else if (activation.name.equals_insensitive("tanh")) | ||
res = rewriter.create<ONNXTanhOp>(loc, resType, operand); | ||
else if (activation.name.equals_insensitive("sigmoid")) | ||
res = rewriter.create<ONNXSigmoidOp>(loc, resType, operand); | ||
else if (activation.name.equals_insensitive("affine")) | ||
llvm_unreachable("Unsupported activation"); | ||
else if (activation.name.equals_insensitive("leakyrelu")) | ||
res = rewriter.create<ONNXLeakyReluOp>(loc, resType, operand, attributes); | ||
else if (activation.name.equals_insensitive("thresholdedrelu")) | ||
res = rewriter.create<ONNXThresholdedReluOp>( | ||
loc, resType, operand, attributes); | ||
else if (activation.name.equals_insensitive("scaledtanh")) | ||
llvm_unreachable("Unsupported activation"); | ||
else if (activation.name.equals_insensitive("hardsigmoid")) | ||
res = rewriter.create<ONNXHardSigmoidOp>(loc, resType, operand, attributes); | ||
else if (activation.name.equals_insensitive("elu")) | ||
res = rewriter.create<ONNXEluOp>(loc, resType, operand, attributes); | ||
else if (activation.name.equals_insensitive("softsign")) | ||
res = rewriter.create<ONNXSoftsignOp>(loc, resType, operand); | ||
else if (activation.name.equals_insensitive("softplus")) | ||
res = rewriter.create<ONNXSoftplusOp>(loc, resType, operand); | ||
else | ||
llvm_unreachable("Unsupported activation"); | ||
|
||
return res; | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===// | ||
// | ||
// Copyright 2019-2024 The IBM Research Authors. | ||
// Modifications Copyright 2023-2024 | ||
// | ||
// ============================================================================= | ||
// | ||
// This file defines common base utilities for lowering the ONNX RNN Operators. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
#include "src/Dialect/ONNX/ONNXOps.hpp" | ||
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" | ||
|
||
static constexpr llvm::StringRef FORWARD = "forward"; | ||
static constexpr llvm::StringRef REVERSE = "reverse"; | ||
static constexpr llvm::StringRef BIDIRECTIONAL = "bidirectional"; | ||
|
||
namespace onnx_mlir { | ||
|
||
struct RNNActivation { | ||
llvm::StringRef name; | ||
std::optional<mlir::FloatAttr> alpha; | ||
std::optional<mlir::FloatAttr> beta; | ||
}; | ||
|
||
/// Get a dimension of the tensor's shape. | ||
int64_t dimAt(mlir::Value val, int index); | ||
|
||
/// Apply an activation function on a given operand. | ||
mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc, | ||
RNNActivation activation, mlir::Value operand); | ||
|
||
// Override the following methods when lowering an RNN operation: | ||
// - hasAllNoneOutput | ||
// - getActivationPack | ||
|
||
// Check whether all outputs have NoneType or not. | ||
template <typename RNNOp> | ||
bool hasAllNoneOutput(RNNOp *op); | ||
|
||
// Obtain activations functions for a specific operation. | ||
template <typename RNNOp, typename A> | ||
std::tuple<A, A> getActivationPack(RNNOp *op); | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.