move common RNN utilities into src/Conversion/ONNXConversionCommon
Signed-off-by: Yan Xu <[email protected]>
Connor-XY committed Jan 18, 2024
1 parent 3f4797d commit 9ec2829
Showing 14 changed files with 361 additions and 482 deletions.
1 change: 1 addition & 0 deletions src/Conversion/CMakeLists.txt
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

add_subdirectory(ONNXConversionCommon)
add_subdirectory(ONNXToKrnl)
add_subdirectory(KrnlToLLVM)
add_subdirectory(KrnlToAffine)
12 changes: 12 additions & 0 deletions src/Conversion/ONNXConversionCommon/CMakeLists.txt
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

add_onnx_mlir_library(OMONNXConversionCommon
  RNN/RNNBase.cpp
  RNN/LSTM.cpp

  LINK_LIBS PUBLIC

  OMONNXOps
  OMSupport
  MLIRTransforms
  )
154 changes: 154 additions & 0 deletions src/Conversion/ONNXConversionCommon/RNN/LSTM.cpp
@@ -0,0 +1,154 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===//
//
// Copyright 2019-2024 The IBM Research Authors.
// Modifications Copyright 2023-2024
//
// =============================================================================
//
// This file includes utilities for lowering the ONNX LSTM Operators.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXConversionCommon/RNN/LSTM.hpp"

using namespace mlir;

namespace onnx_mlir {

template <>
bool hasAllNoneOutput<ONNXLSTMOp>(ONNXLSTMOp *op) {
  return (isNoneValue(op->getY()) && isNoneValue(op->getYH()) &&
          isNoneValue(op->getYC()));
}

template <>
std::tuple<LstmActivationPack, LstmActivationPack>
getActivationPack<ONNXLSTMOp, LstmActivationPack>(ONNXLSTMOp *op) {
  auto direction = op->getDirection();
  auto activations = op->getActivations();
  auto activationAlpha = op->getActivationAlpha();
  auto activationBeta = op->getActivationBeta();

  LstmActivationPack activationForward, activationReverse;

  // Get activation function name.
  // Default forward functions
  activationForward.f.name = "sigmoid";
  activationForward.g.name = "tanh";
  activationForward.h.name = "tanh";
  // Default backward functions
  activationReverse.f.name = "sigmoid";
  activationReverse.g.name = "tanh";
  activationReverse.h.name = "tanh";
  if (activations) {
    ArrayAttr activationArrAttr = activations.value();
    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      // Forward activations.
      if (activationArrAttr.size() > 0) {
        activationForward.f.name =
            activationArrAttr[0].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > 1) {
        activationForward.g.name =
            activationArrAttr[1].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > 2) {
        activationForward.h.name =
            activationArrAttr[2].cast<StringAttr>().getValue();
      }
    }

    // Reverse activations.
    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      unsigned int startIndex = (direction == REVERSE) ? 0 : 3;
      if (activationArrAttr.size() > startIndex) {
        activationReverse.f.name =
            activationArrAttr[startIndex].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > startIndex + 1) {
        activationReverse.g.name =
            activationArrAttr[startIndex + 1].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > startIndex + 2) {
        activationReverse.h.name =
            activationArrAttr[startIndex + 2].cast<StringAttr>().getValue();
      }
    }
  }

  // Get alpha attributes.
  if (activationAlpha) {
    ArrayAttr activationArrAttr = activationAlpha.value();
    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      // Forward activations.
      if (activationArrAttr.size() > 0) {
        activationForward.f.alpha = activationArrAttr[0].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 1) {
        activationForward.g.alpha = activationArrAttr[1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 2) {
        activationForward.h.alpha = activationArrAttr[2].cast<FloatAttr>();
      }
    }

    // Reverse activations.
    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      unsigned int startIndex = (direction == REVERSE) ? 0 : 3;
      if (activationArrAttr.size() > startIndex) {
        activationReverse.f.alpha =
            activationArrAttr[startIndex].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 1) {
        activationReverse.g.alpha =
            activationArrAttr[startIndex + 1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 2) {
        activationReverse.h.alpha =
            activationArrAttr[startIndex + 2].cast<FloatAttr>();
      }
    }
  }

  // Get beta attributes.
  if (activationBeta) {
    ArrayAttr activationArrAttr = activationBeta.value();
    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      // Forward activations.
      if (activationArrAttr.size() > 0) {
        activationForward.f.beta = activationArrAttr[0].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 1) {
        activationForward.g.beta = activationArrAttr[1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 2) {
        activationForward.h.beta = activationArrAttr[2].cast<FloatAttr>();
      }
    }

    // Reverse activations.
    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      unsigned int startIndex = (direction == REVERSE) ? 0 : 3;
      if (activationArrAttr.size() > startIndex) {
        activationReverse.f.beta =
            activationArrAttr[startIndex].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 1) {
        activationReverse.g.beta =
            activationArrAttr[startIndex + 1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 2) {
        activationReverse.h.beta =
            activationArrAttr[startIndex + 2].cast<FloatAttr>();
      }
    }
  }

  return std::make_tuple(activationForward, activationReverse);
}

} // namespace onnx_mlir
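
For reference, a minimal standalone sketch (plain C++, no MLIR; the activation names are made up for illustration and are not part of this commit) of how the startIndex logic above splits a bidirectional LSTM's six-entry activations list into the forward and reverse packs:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // For a bidirectional LSTM, entries 0..2 are the forward f/g/h activations
  // and entries 3..5 are the reverse ones, hence startIndex == 3 above.
  std::vector<std::string> activations = {
      "Sigmoid", "Tanh", "Tanh", "HardSigmoid", "Tanh", "Tanh"};
  unsigned startIndex = 3; // would be 0 for direction == "reverse"
  assert(activations[0] == "Sigmoid");              // activationForward.f
  assert(activations[2] == "Tanh");                 // activationForward.h
  assert(activations[startIndex] == "HardSigmoid"); // activationReverse.f
  assert(activations[startIndex + 2] == "Tanh");    // activationReverse.h
  return 0;
}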
54 changes: 54 additions & 0 deletions src/Conversion/ONNXConversionCommon/RNN/LSTM.hpp
@@ -0,0 +1,54 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------- LSTM.hpp - Lowering LSTM Op --------------------------===//
//
// Copyright 2024
//
// =============================================================================
//
// This file includes utilities for lowering the ONNX LSTM Operators.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp"

namespace onnx_mlir {

struct LstmActivationPack {
  RNNActivation f;
  RNNActivation g;
  RNNActivation h;
};

struct LstmWeightPack {
  mlir::Value WT;
  mlir::Value RT;
};

struct LstmBiasPack {
  bool hasBias = false;
  mlir::Value Wbi;
  mlir::Value Wbo;
  mlir::Value Wbf;
  mlir::Value Wbc;
  mlir::Value Rbi;
  mlir::Value Rbo;
  mlir::Value Rbf;
  mlir::Value Rbc;
  // Put peephole here.
  bool hasPeephole = false;
  mlir::Value Pi;
  mlir::Value Po;
  mlir::Value Pf;
};

template <>
bool hasAllNoneOutput<mlir::ONNXLSTMOp>(mlir::ONNXLSTMOp *op);

template <>
std::tuple<LstmActivationPack, LstmActivationPack>
getActivationPack<mlir::ONNXLSTMOp, LstmActivationPack>(mlir::ONNXLSTMOp *op);

} // namespace onnx_mlir
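
As a usage illustration only (a hypothetical sketch, not code from this commit; the helper name and gate values are invented), a lowering pattern would typically combine these declarations with applyActivation from RNNBase.hpp along these lines:

#include <tuple>

#include "src/Conversion/ONNXConversionCommon/RNN/LSTM.hpp"

namespace onnx_mlir {

// Hypothetical helper: apply the forward f/g activations to two gate values.
static void emitForwardGates(mlir::OpBuilder &rewriter, mlir::Location loc,
    mlir::ONNXLSTMOp *op, mlir::Value gateInput, mlir::Value cellCandidate) {
  // Nothing to emit when the op produces no outputs at all.
  if (hasAllNoneOutput<mlir::ONNXLSTMOp>(op))
    return;
  LstmActivationPack forward, reverse;
  std::tie(forward, reverse) =
      getActivationPack<mlir::ONNXLSTMOp, LstmActivationPack>(op);
  // f gates the input/forget/output paths, g squashes the cell candidate.
  mlir::Value gated = applyActivation(rewriter, loc, forward.f, gateInput);
  mlir::Value cand = applyActivation(rewriter, loc, forward.g, cellCandidate);
  (void)gated;
  (void)cand;
  (void)reverse;
}

} // namespace onnx_mlir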
73 changes: 73 additions & 0 deletions src/Conversion/ONNXConversionCommon/RNN/RNNBase.cpp
@@ -0,0 +1,73 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019-2024 The IBM Research Authors.
// Modifications Copyright 2023-2024
//
// =============================================================================
//
// This file defines common base utilities for lowering the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp"

using namespace mlir;

namespace onnx_mlir {

// Get a dimension of the tensor's shape.
int64_t dimAt(Value val, int index) {
  return val.getType().cast<ShapedType>().getShape()[index];
}

// Apply an activation function on a given scalar operand.
Value applyActivation(OpBuilder &rewriter, Location loc,
    RNNActivation activation, Value operand) {
  Value res;

  std::vector<mlir::NamedAttribute> attributes;
  if (activation.alpha) {
    attributes.emplace_back(
        rewriter.getNamedAttr("alpha", activation.alpha.value()));
  }
  if (activation.beta) {
    attributes.emplace_back(
        rewriter.getNamedAttr("beta", activation.beta.value()));
  }
  Type resType = operand.getType();

  // Match the activation name case-insensitively.
  if (activation.name.equals_insensitive("relu"))
    res = rewriter.create<ONNXReluOp>(loc, resType, operand);
  else if (activation.name.equals_insensitive("tanh"))
    res = rewriter.create<ONNXTanhOp>(loc, resType, operand);
  else if (activation.name.equals_insensitive("sigmoid"))
    res = rewriter.create<ONNXSigmoidOp>(loc, resType, operand);
  else if (activation.name.equals_insensitive("affine"))
    llvm_unreachable("Unsupported activation");
  else if (activation.name.equals_insensitive("leakyrelu"))
    res = rewriter.create<ONNXLeakyReluOp>(loc, resType, operand, attributes);
  else if (activation.name.equals_insensitive("thresholdedrelu"))
    res = rewriter.create<ONNXThresholdedReluOp>(
        loc, resType, operand, attributes);
  else if (activation.name.equals_insensitive("scaledtanh"))
    llvm_unreachable("Unsupported activation");
  else if (activation.name.equals_insensitive("hardsigmoid"))
    res = rewriter.create<ONNXHardSigmoidOp>(loc, resType, operand, attributes);
  else if (activation.name.equals_insensitive("elu"))
    res = rewriter.create<ONNXEluOp>(loc, resType, operand, attributes);
  else if (activation.name.equals_insensitive("softsign"))
    res = rewriter.create<ONNXSoftsignOp>(loc, resType, operand);
  else if (activation.name.equals_insensitive("softplus"))
    res = rewriter.create<ONNXSoftplusOp>(loc, resType, operand);
  else
    llvm_unreachable("Unsupported activation");

  return res;
}

} // namespace onnx_mlir
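
As a small usage sketch (not part of this commit; the 0.1 slope is an arbitrary example), an RNNActivation can be built by hand and passed to applyActivation; the optional alpha/beta fields are forwarded as op attributes only when present:

#include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp"

namespace onnx_mlir {

// Hypothetical helper: emit a LeakyRelu through the generic dispatch above.
static mlir::Value emitLeakyRelu(
    mlir::OpBuilder &rewriter, mlir::Location loc, mlir::Value operand) {
  RNNActivation act;
  act.name = "leakyrelu";                     // matched case-insensitively
  act.alpha = rewriter.getF32FloatAttr(0.1f); // forwarded as the "alpha" attr
  act.beta = std::nullopt;                    // no "beta" attribute emitted
  return applyActivation(rewriter, loc, act, operand);
}

} // namespace onnx_mlir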
55 changes: 55 additions & 0 deletions src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp
@@ -0,0 +1,55 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019-2024 The IBM Research Authors.
// Modifications Copyright 2023-2024
//
// =============================================================================
//
// This file defines common base utilities for lowering the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"

static constexpr llvm::StringRef FORWARD = "forward";
static constexpr llvm::StringRef REVERSE = "reverse";
static constexpr llvm::StringRef BIDIRECTIONAL = "bidirectional";

namespace onnx_mlir {

struct RNNActivation {
  llvm::StringRef name;
  std::optional<mlir::FloatAttr> alpha;
  std::optional<mlir::FloatAttr> beta;
};

/// Get a dimension of the tensor's shape.
int64_t dimAt(mlir::Value val, int index);

/// Apply an activation function on a given operand.
mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc,
    RNNActivation activation, mlir::Value operand);

// Override the following methods when lowering an RNN operation:
// - hasAllNoneOutput
// - getActivationPack

// Check whether all outputs have NoneType or not.
template <typename RNNOp>
bool hasAllNoneOutput(RNNOp *op);

// Obtain activations functions for a specific operation.
template <typename RNNOp, typename A>
std::tuple<A, A> getActivationPack(RNNOp *op);

} // namespace onnx_mlir
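
To hook another RNN op into this base, a conversion defines its own activation-pack struct and specializes the two hooks named in the comment above. A purely illustrative sketch (not part of this commit; the pack struct is invented and the ONNXRNNOp accessor names are assumed):

#include "src/Conversion/ONNXConversionCommon/RNN/RNNBase.hpp"

namespace onnx_mlir {

// Illustrative activation pack for a single-gate RNN cell.
struct SimpleRnnActivationPack {
  RNNActivation f;
};

template <>
bool hasAllNoneOutput<mlir::ONNXRNNOp>(mlir::ONNXRNNOp *op) {
  // ONNX RNN has two results, Y and Y_h (accessor names assumed here).
  return isNoneValue(op->getY()) && isNoneValue(op->getYH());
}

template <>
std::tuple<SimpleRnnActivationPack, SimpleRnnActivationPack>
getActivationPack<mlir::ONNXRNNOp, SimpleRnnActivationPack>(
    mlir::ONNXRNNOp *op) {
  SimpleRnnActivationPack forward, reverse;
  // "Tanh" is the ONNX-specified default activation for RNN.
  forward.f.name = reverse.f.name = "tanh";
  return std::make_tuple(forward, reverse);
}

} // namespace onnx_mlir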
9 changes: 5 additions & 4 deletions src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -43,7 +43,7 @@ add_onnx_mlir_library(OMONNXToKrnl
Sequence/SequenceInsert.cpp
Sequence/SequenceLength.cpp
Tensor/ArgMinMax.cpp
Tensor/Compress.cpp
Tensor/Compress.cpp
Tensor/Concat.cpp
Tensor/ConcatShapeTranspose.cpp
Tensor/Constant.cpp
@@ -57,17 +57,17 @@ add_onnx_mlir_library(OMONNXToKrnl
Tensor/GatherND.cpp
Tensor/Identity.cpp
Tensor/NonZero.cpp
Tensor/OneHot.cpp
Tensor/OneHot.cpp
Tensor/Pad.cpp
Tensor/PrintSignature.cpp
Tensor/Range.cpp
Tensor/Reshape.cpp
Tensor/Resize.cpp
Tensor/ReverseSequence.cpp
Tensor/ScatterElements.cpp
Tensor/ScatterND.cpp
Tensor/ScatterND.cpp
Tensor/Shape.cpp
Tensor/Size.cpp
Tensor/Size.cpp
Tensor/Slice.cpp
Tensor/SpaceToDepth.cpp
Tensor/Split.cpp
@@ -79,6 +79,7 @@ add_onnx_mlir_library(OMONNXToKrnl

LINK_LIBS PUBLIC
OMAccelerator
OMONNXConversionCommon
OMONNXOps
OMSupport
MLIRFuncDialect