Skip to content

Commit

Permalink
rename StableHlo to Stablehlo
Browse files Browse the repository at this point in the history
Signed-off-by: Yan Xu <[email protected]>
  • Loading branch information
Connor-XY committed Dec 27, 2023
1 parent 9003ce1 commit c9bee0e
Show file tree
Hide file tree
Showing 46 changed files with 463 additions and 463 deletions.
2 changes: 1 addition & 1 deletion src/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ add_subdirectory(KrnlSeqToMemref)
add_subdirectory(ONNXToTOSA)

if (ONNX_MLIR_ENABLE_STABLEHLO)
add_subdirectory(ONNXToStableHlo)
add_subdirectory(ONNXToStablehlo)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ install(TARGETS
DESTINATION lib
)

add_onnx_mlir_library(OMONNXToStableHlo
ConvertONNXToStableHlo.cpp
ONNXToStableHloCommon.cpp
add_onnx_mlir_library(OMONNXToStablehlo
ConvertONNXToStablehlo.cpp
ONNXToStablehloCommon.cpp
DialectBuilder.cpp

Math/Clip.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,85 +2,85 @@
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ConvertONNXToStableHlo.cpp - ONNX dialects to StableHlo lowering
//====------ ConvertONNXToStablehlo.cpp - ONNX dialects to Stablehlo lowering
//-------===//
//
// Copyright 2022
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// StableHlo IR and standard operations.
// Stablehlo IR and standard operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"

#include "src/Conversion/ONNXToStableHlo/ONNXToStableHloCommon.hpp"
#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

void populateONNXToStableHloConversionPattern(
void populateONNXToStablehloConversionPattern(
RewritePatternSet &patterns, MLIRContext *ctx, bool enableUnroll) {
// Math
populateLoweringONNXClipOpToStableHloPattern(patterns, ctx);
populateLoweringONNXElementwiseOpToStableHloPattern(patterns, ctx);
populateLoweringONNXGemmOpToStableHloPattern(patterns, ctx);
populateLoweringONNXMatMulOpToStableHloPattern(patterns, ctx);
populateLoweringONNXReductionOpToStableHloPattern(patterns, ctx);
populateLoweringONNXClipOpToStablehloPattern(patterns, ctx);
populateLoweringONNXElementwiseOpToStablehloPattern(patterns, ctx);
populateLoweringONNXGemmOpToStablehloPattern(patterns, ctx);
populateLoweringONNXMatMulOpToStablehloPattern(patterns, ctx);
populateLoweringONNXReductionOpToStablehloPattern(patterns, ctx);
// Neural network
populateLoweringONNXConvOpToStableHloPattern(patterns, ctx);
populateLoweringONNXConvTransposeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXNormalizationOpToStableHloPattern(patterns, ctx);
populateLoweringONNXPoolingOpToStableHloPattern(patterns, ctx);
populateLoweringONNXConvOpToStablehloPattern(patterns, ctx);
populateLoweringONNXConvTransposeOpToStablehloPattern(patterns, ctx);
populateLoweringONNXNormalizationOpToStablehloPattern(patterns, ctx);
populateLoweringONNXPoolingOpToStablehloPattern(patterns, ctx);
// Recurrent neural network
populateLoweringONNXLSTMOpToStableHloPattern(patterns, ctx, enableUnroll);
populateLoweringONNXLSTMOpToStablehloPattern(patterns, ctx, enableUnroll);
// Tensor
populateLoweringONNXArgMaxOpToStableHloPattern(patterns, ctx);
populateLoweringONNXConcatOpToStableHloPattern(patterns, ctx);
populateLoweringONNXConstantOpToStableHloPattern(patterns, ctx);
populateLoweringONNXDepthToSpaceOpToStableHloPattern(patterns, ctx);
populateLoweringONNXExpandOpToStableHloPattern(patterns, ctx);
populateLoweringONNXFlattenOpToStableHloPattern(patterns, ctx);
populateLoweringONNXGatherOpToStableHloPattern(patterns, ctx);
populateLoweringONNXGatherElementsOpToStableHloPattern(patterns, ctx);
populateLoweringONNXIdentityOpToStableHloPattern(patterns, ctx);
populateLoweringONNXOneHotOpToStableHloPattern(patterns, ctx);
populateLoweringONNXPadOpToStableHloPattern(patterns, ctx);
populateLoweringONNXReshapeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXScatterNDOpToStableHloPattern(patterns, ctx);
populateLoweringONNXShapeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXSliceOpToStableHloPattern(patterns, ctx);
populateLoweringONNXSplitOpToStableHloPattern(patterns, ctx);
populateLoweringONNXSqueezeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXTileOpToStableHloPattern(patterns, ctx);
populateLoweringONNXTransposeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXUnsqueezeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXArgMaxOpToStablehloPattern(patterns, ctx);
populateLoweringONNXConcatOpToStablehloPattern(patterns, ctx);
populateLoweringONNXConstantOpToStablehloPattern(patterns, ctx);
populateLoweringONNXDepthToSpaceOpToStablehloPattern(patterns, ctx);
populateLoweringONNXExpandOpToStablehloPattern(patterns, ctx);
populateLoweringONNXFlattenOpToStablehloPattern(patterns, ctx);
populateLoweringONNXGatherOpToStablehloPattern(patterns, ctx);
populateLoweringONNXGatherElementsOpToStablehloPattern(patterns, ctx);
populateLoweringONNXIdentityOpToStablehloPattern(patterns, ctx);
populateLoweringONNXOneHotOpToStablehloPattern(patterns, ctx);
populateLoweringONNXPadOpToStablehloPattern(patterns, ctx);
populateLoweringONNXReshapeOpToStablehloPattern(patterns, ctx);
populateLoweringONNXScatterNDOpToStablehloPattern(patterns, ctx);
populateLoweringONNXShapeOpToStablehloPattern(patterns, ctx);
populateLoweringONNXSliceOpToStablehloPattern(patterns, ctx);
populateLoweringONNXSplitOpToStablehloPattern(patterns, ctx);
populateLoweringONNXSqueezeOpToStablehloPattern(patterns, ctx);
populateLoweringONNXTileOpToStablehloPattern(patterns, ctx);
populateLoweringONNXTransposeOpToStablehloPattern(patterns, ctx);
populateLoweringONNXUnsqueezeOpToStablehloPattern(patterns, ctx);
}

//===----------------------------------------------------------------------===//
// Frontend to StableHlo Dialect lowering pass
// Frontend to Stablehlo Dialect lowering pass
//===----------------------------------------------------------------------===//

struct FrontendToStableHloLoweringPass
: public PassWrapper<FrontendToStableHloLoweringPass,
struct FrontendToStablehloLoweringPass
: public PassWrapper<FrontendToStablehloLoweringPass,
OperationPass<ModuleOp>> {

StringRef getArgument() const override { return "convert-onnx-to-stablehlo"; }

StringRef getDescription() const override {
return "Lower frontend ops to StableHlo dialect.";
return "Lower frontend ops to Stablehlo dialect.";
}

// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
FrontendToStableHloLoweringPass() = default;
FrontendToStableHloLoweringPass(const FrontendToStableHloLoweringPass &pass)
: PassWrapper<FrontendToStableHloLoweringPass,
FrontendToStablehloLoweringPass() = default;
FrontendToStablehloLoweringPass(const FrontendToStablehloLoweringPass &pass)
: PassWrapper<FrontendToStablehloLoweringPass,
OperationPass<ModuleOp>>() {}
FrontendToStableHloLoweringPass(bool enableUnroll) {
FrontendToStablehloLoweringPass(bool enableUnroll) {
// Below, need explicit assignment to enable implicit conversion of bool
// to Option<bool>.
this->enableUnroll = enableUnroll;
Expand All @@ -103,7 +103,7 @@ struct FrontendToStableHloLoweringPass
llvm::cl::init(true)};
};

void FrontendToStableHloLoweringPass::runOnOperation() {
void FrontendToStablehloLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
Expand All @@ -122,7 +122,7 @@ void FrontendToStableHloLoweringPass::runOnOperation() {
RewritePatternSet patterns(&getContext());

// Define patterns.
populateONNXToStableHloConversionPattern(
populateONNXToStablehloConversionPattern(
patterns, &getContext(), enableUnroll);

// add illegal op
Expand All @@ -136,12 +136,12 @@ void FrontendToStableHloLoweringPass::runOnOperation() {
}
}

std::unique_ptr<Pass> createLowerToStableHloPass() {
return std::make_unique<FrontendToStableHloLoweringPass>();
std::unique_ptr<Pass> createLowerToStablehloPass() {
return std::make_unique<FrontendToStablehloLoweringPass>();
}

std::unique_ptr<Pass> createLowerToStableHloPass(bool enableUnroll) {
return std::make_unique<FrontendToStableHloLoweringPass>(enableUnroll);
std::unique_ptr<Pass> createLowerToStablehloPass(bool enableUnroll) {
return std::make_unique<FrontendToStablehloLoweringPass>(enableUnroll);
}

} // namespace onnx_mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
* SPDX-License-Identifier: Apache-2.0
*/

//====------ DialectBuilder.cpp - StableHlo dialect builder
//====------ DialectBuilder.cpp - Stablehlo dialect builder
//--------------------===//
//
// Copyright 2022
//
// =============================================================================
//
// This file contains dialect builder for StableHlo dialect.
// This file contains dialect builder for Stablehlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToStableHlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToStablehlo/DialectBuilder.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Support/TypeUtilities.hpp"
Expand Down Expand Up @@ -200,11 +200,11 @@ Value OnnxToStablehloBuilder::transpose(const Value input,
}

// =============================================================================
// IndexExpr Builder for Lowering using Shape/StableHlo Dialect.
// IndexExpr Builder for Lowering using Shape/Stablehlo Dialect.
// =============================================================================

// Return null if none is found.
ElementsAttr IndexExprBuilderForStableHlo::getConst(Value value) {
ElementsAttr IndexExprBuilderForStablehlo::getConst(Value value) {
auto definingOp = value.getDefiningOp();
// If we have a cast between index/integer, skip it, i.e. get the defining op
// that is the input to the cast.
Expand All @@ -222,7 +222,7 @@ ElementsAttr IndexExprBuilderForStableHlo::getConst(Value value) {
return nullptr;
}

Value IndexExprBuilderForStableHlo::getVal(Value intArrayVal, uint64_t i) {
Value IndexExprBuilderForStablehlo::getVal(Value intArrayVal, uint64_t i) {
Type elemType = getElementType(intArrayVal.getType());
if (!elemType.isa<IndexType>()) {
Type indexTensorType = RankedTensorType::get(
Expand All @@ -235,7 +235,7 @@ Value IndexExprBuilderForStableHlo::getVal(Value intArrayVal, uint64_t i) {
return createShape.getExtent(intArrayVal, i);
}

Value IndexExprBuilderForStableHlo::getShapeVal(
Value IndexExprBuilderForStablehlo::getShapeVal(
Value tensorOrMemrefValue, uint64_t i) {
ShapeBuilder createShape(*this);
return createShape.dim(tensorOrMemrefValue, i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
* SPDX-License-Identifier: Apache-2.0
*/

//====------ DialectBuilder.hpp - StableHlo dialect builder
//====------ DialectBuilder.hpp - Stablehlo dialect builder
//--------------------===//
//
// Copyright 2022
//
// =============================================================================
//
// This file contains dialect builder for StableHlo dialect.
// This file contains dialect builder for Stablehlo dialect.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -104,13 +104,13 @@ struct OnnxToStablehloBuilder : public OnnxBuilder {
// IndexExpr Builder for Shape lowering
// =============================================================================

struct IndexExprBuilderForStableHlo : IndexExprBuilder {
IndexExprBuilderForStableHlo(mlir::Location loc) : IndexExprBuilder(loc) {}
IndexExprBuilderForStableHlo(mlir::OpBuilder &b, mlir::Location loc)
struct IndexExprBuilderForStablehlo : IndexExprBuilder {
IndexExprBuilderForStablehlo(mlir::Location loc) : IndexExprBuilder(loc) {}
IndexExprBuilderForStablehlo(mlir::OpBuilder &b, mlir::Location loc)
: IndexExprBuilder(b, loc) {}
IndexExprBuilderForStableHlo(const DialectBuilder &db)
IndexExprBuilderForStablehlo(const DialectBuilder &db)
: IndexExprBuilder(db) {}
virtual ~IndexExprBuilderForStableHlo() {}
virtual ~IndexExprBuilderForStablehlo() {}

protected:
mlir::ElementsAttr getConst(mlir::Value value) final;
Expand Down Expand Up @@ -146,16 +146,16 @@ struct MultiDialectBuilder<OnnxToStablehloBuilder, Ts...>
OnnxToStablehloBuilder stablehloOnnx;
};

// Recursive class specialized for IndexExprBuilderForStableHlo referred to as
// Recursive class specialized for IndexExprBuilderForStablehlo referred to as
// stableHloIE.
template <class... Ts>
struct MultiDialectBuilder<IndexExprBuilderForStableHlo, Ts...>
struct MultiDialectBuilder<IndexExprBuilderForStablehlo, Ts...>
: MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), stableHloIE(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), stableHloIE(db) {}
IndexExprBuilderForStableHlo stableHloIE;
IndexExprBuilderForStablehlo stableHloIE;
};

} // namespace onnx_mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
//
// =============================================================================
//
// This file lowers the ONNX Clip Operator to StableHlo dialect.
// This file lowers the ONNX Clip Operator to Stablehlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToStableHlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToStableHlo/ONNXToStableHloCommon.hpp"
#include "src/Conversion/ONNXToStablehlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;
Expand All @@ -24,16 +24,16 @@ namespace onnx_mlir {
// Scalar unary ops for lowering ONNXClipOp
//===----------------------------------------------------------------------===//

struct ONNXClipOpLoweringToStableHlo : public ConversionPattern {
ONNXClipOpLoweringToStableHlo(MLIRContext *ctx)
struct ONNXClipOpLoweringToStablehlo : public ConversionPattern {
ONNXClipOpLoweringToStablehlo(MLIRContext *ctx)
: ConversionPattern(ONNXClipOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();

ONNXClipOpAdaptor operandAdaptor(operands);
IndexExprBuilderForStableHlo createIE(rewriter, loc);
IndexExprBuilderForStablehlo createIE(rewriter, loc);
ONNXClipOpShapeHelper shapeHelper(op, operands, &createIE);
auto shapeComputed = shapeHelper.computeShape();
assert(succeeded(shapeComputed) && "Could not compute output shape");
Expand Down Expand Up @@ -66,9 +66,9 @@ struct ONNXClipOpLoweringToStableHlo : public ConversionPattern {
}
};

void populateLoweringONNXClipOpToStableHloPattern(
void populateLoweringONNXClipOpToStablehloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXClipOpLoweringToStableHlo>(ctx);
patterns.insert<ONNXClipOpLoweringToStablehlo>(ctx);
}

} // namespace onnx_mlir
Loading

0 comments on commit c9bee0e

Please sign in to comment.