diff --git a/src/Conversion/ONNXToMhlo/CMakeLists.txt b/src/Conversion/ONNXToMhlo/CMakeLists.txt
index 9b5e57a87d..cd26fb1efb 100644
--- a/src/Conversion/ONNXToMhlo/CMakeLists.txt
+++ b/src/Conversion/ONNXToMhlo/CMakeLists.txt
@@ -7,7 +7,9 @@ add_onnx_mlir_library(OMONNXToMhlo
   Math/Elementwise.cpp
   Math/Gemm.cpp
   Math/Softmax.cpp
+  NN/Normalization.cpp
   NN/Pooling.cpp
+  Tensor/Concat.cpp
   Tensor/Constant.cpp
   Tensor/Reshape.cpp
 
@@ -24,7 +26,7 @@ add_onnx_mlir_library(OMONNXToMhlo
   OMSupport
   )
 
-  target_include_directories(OMONNXToMhlo PUBLIC
-    ${CMAKE_SOURCE_DIR}/third_party/mlir-hlo/include
-    ${CMAKE_BINARY_DIR}/third_party/mlir-hlo/include
+target_include_directories(OMONNXToMhlo PUBLIC
+  ${ONNX_MLIR_SRC_ROOT}/third_party/mlir-hlo/include
+  ${ONNX_MLIR_BIN_ROOT}/third_party/mlir-hlo/include
   )
diff --git a/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp b/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp
index 02ea7eebf8..f84b8f4750 100644
--- a/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp
+++ b/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp
@@ -26,8 +26,10 @@ void populateONNXToMhloConversionPattern(
   populateLoweringONNXSoftmaxOpToMhloPattern(patterns, ctx);
   populateLoweringONNXGemmOpToMhloPattern(patterns, ctx);
   // Neural network
+  populateLoweringONNXNormalizationOpToMhloPattern(patterns, ctx);
   populateLoweringONNXPoolingOpToMhloPattern(patterns, ctx);
   // Tensor
+  populateLoweringONNXConcatOpToMhloPattern(patterns, ctx);
   populateLoweringONNXConstantOpToMhloPattern(patterns, ctx);
   populateLoweringONNXReshapeOpToMhloPattern(patterns, ctx);
 }
diff --git a/src/Conversion/ONNXToMhlo/NN/Normalization.cpp b/src/Conversion/ONNXToMhlo/NN/Normalization.cpp
new file mode 100644
index 0000000000..4a1dd6a20a
--- /dev/null
+++ b/src/Conversion/ONNXToMhlo/NN/Normalization.cpp
@@ -0,0 +1,59 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
+//
+// Copyright 2019-2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers ONNX Normalization Operators to Mhlo dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+namespace {
+
+struct ONNXBatchNormalizationInferenceModeOpLoweringToMhlo
+    : public ConversionPattern {
+  ONNXBatchNormalizationInferenceModeOpLoweringToMhlo(MLIRContext *ctx)
+      : ConversionPattern(
+            mlir::ONNXBatchNormalizationInferenceModeOp::getOperationName(), 1,
+            ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // batchnorm{epsilon}(x, scale, bias, mean, variance) =
+    //     scale * (x - mean) / sqrt(variance + epsilon) + bias
+    ONNXBatchNormalizationInferenceModeOpAdaptor operandAdaptor(
+        operands, op->getAttrDictionary());
+    Location loc = op->getLoc();
+
+    Value operand = operandAdaptor.X();
+    Value scale = operandAdaptor.scale();
+    Value bias = operandAdaptor.B();
+    Value mean = operandAdaptor.mean();
+    Value variance = operandAdaptor.var();
+    llvm::APFloat eps = operandAdaptor.epsilon();
+
+    Value result = rewriter.create<mhlo::BatchNormInferenceOp>(loc,
+        op->getResultTypes(), operand, scale, bias, mean, variance, eps, 1);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateLoweringONNXNormalizationOpToMhloPattern(
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXBatchNormalizationInferenceModeOpLoweringToMhlo>(ctx);
+}
+
+} // namespace onnx_mlir
diff --git a/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp b/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp
index 2227088d96..ad5152603c 100644
--- a/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp
+++ b/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp
@@ -58,9 +58,13 @@ void populateLoweringONNXSoftmaxOpToMhloPattern(
     RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXGemmOpToMhloPattern(
     RewritePatternSet &, MLIRContext *);
 // `NN` directory methods:
+void populateLoweringONNXNormalizationOpToMhloPattern(
+    RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXPoolingOpToMhloPattern(
     RewritePatternSet &, MLIRContext *);
 // `Tensor` directory methods:
+void populateLoweringONNXConcatOpToMhloPattern(
+    RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXConstantOpToMhloPattern(
     RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXReshapeOpToMhloPattern(
diff --git a/src/Conversion/ONNXToMhlo/Tensor/Concat.cpp b/src/Conversion/ONNXToMhlo/Tensor/Concat.cpp
new file mode 100644
index 0000000000..7b843aec14
--- /dev/null
+++ b/src/Conversion/ONNXToMhlo/Tensor/Concat.cpp
@@ -0,0 +1,65 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
+//
+// Copyright 2019-2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Concat Operator to Mhlo dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/Debug.h"
+
+#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
+#include "src/Support/TypeUtilities.hpp"
+
+#define DEBUG_TYPE "onnx_to_mhlo"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+namespace {
+
+struct ONNXConcatOpLoweringToMhlo : public ConversionPattern {
+  ONNXConcatOpLoweringToMhlo(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+
+    Location loc = op->getLoc();
+    ONNXConcatOpAdaptor operandAdaptor(operands);
+    ONNXConcatOp concatOp = llvm::cast<ONNXConcatOp>(op);
+
+    assert(op->getNumResults() == 1 && "ONNXConcatOp should have 1 result");
+    Type resultType = op->getResult(0).getType();
+    if (!onnx_mlir::isRankedShapedType(resultType)) {
+      LLVM_DEBUG(llvm::dbgs() << "Concat Output Is Not Ranked\n");
+      return failure();
+    }
+    int64_t rank = onnx_mlir::getRank(resultType);
+    int64_t axis = concatOp.axis();
+    axis = axis >= 0 ? axis : rank + axis;
+    assert(axis >= 0 && axis < rank && "Axis out of rank range");
+
+    ValueRange inputs = operandAdaptor.inputs();
+    Value result = rewriter.create<mhlo::ConcatenateOp>(
+        loc, op->getResultTypes(), inputs, rewriter.getI64IntegerAttr(axis));
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateLoweringONNXConcatOpToMhloPattern(
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXConcatOpLoweringToMhlo>(ctx);
+}
+
+} // namespace onnx_mlir
diff --git a/src/Support/TypeUtilities.cpp b/src/Support/TypeUtilities.cpp
index 257c1a0d2d..f5401b2047 100644
--- a/src/Support/TypeUtilities.cpp
+++ b/src/Support/TypeUtilities.cpp
@@ -36,6 +36,12 @@ ArrayRef<int64_t> getShape(Type ty) {
   return ty.cast<ShapedType>().getShape();
 }
 
+/// Get rank.
+int64_t getRank(Type ty) {
+  assert(isRankedShapedType(ty) && "Type must be ranked");
+  return ty.cast<ShapedType>().getRank();
+}
+
 /// Get the number of elements.
 int64_t getNumberOfElements(Type ty) {
   ArrayRef<int64_t> shape = getShape(ty);
diff --git a/src/Support/TypeUtilities.hpp b/src/Support/TypeUtilities.hpp
index 929599843b..50c781f1df 100644
--- a/src/Support/TypeUtilities.hpp
+++ b/src/Support/TypeUtilities.hpp
@@ -22,6 +22,8 @@ mlir::Type getElementType(mlir::Type ty);
 bool isRankedShapedType(mlir::Type ty);
 /// Get shape.
 llvm::ArrayRef<int64_t> getShape(mlir::Type ty);
+/// Get rank.
+int64_t getRank(mlir::Type ty);
 /// Get the number of elements.
 int64_t getNumberOfElements(mlir::Type ty);
 /// Get the element size in bytes.
diff --git a/test/mlir/conversion/onnx_to_mhlo/NN/Normalization.mlir b/test/mlir/conversion/onnx_to_mhlo/NN/Normalization.mlir
new file mode 100644
index 0000000000..2bac09dafd
--- /dev/null
+++ b/test/mlir/conversion/onnx_to_mhlo/NN/Normalization.mlir
@@ -0,0 +1,19 @@
+// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s
+
+func @test_batch_normalization(%arg0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> {
+  %0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
+  %1 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
+  %2 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
+  %3 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
+  %4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32>
+  "func.return"(%4) : (tensor<1x3x10x10xf32>) -> ()
+// CHECK-LABEL: func @test_batch_normalization
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK-NEXT: [[VAR_1_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK-NEXT: [[VAR_2_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.batch_norm_inference"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, feature_index = 1 : i64} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32>
+// CHECK-NEXT: return [[VAR_4_]] : tensor<1x3x10x10xf32>
+// CHECK-NEXT: }
+}
\ No newline at end of file
diff --git a/test/mlir/conversion/onnx_to_mhlo/Tensor/Concat.mlir b/test/mlir/conversion/onnx_to_mhlo/Tensor/Concat.mlir
new file mode 100644
index 0000000000..9e7877f767
--- /dev/null
+++ b/test/mlir/conversion/onnx_to_mhlo/Tensor/Concat.mlir
@@ -0,0 +1,25 @@
+// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s
+
+// Test when output shape is unknown
+func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
+  "func.return"(%0) : (tensor<5x5x?x32xf32>) -> ()
+// CHECK-LABEL: func @test_concat_dynamic_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x?x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
+// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x?x32xf32>
+// CHECK-NEXT: }
+}
+
+// -----
+
+// Test when axis is negative
+func @test_concat_negative_axis(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1) { axis = -2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32>
+  "func.return"(%0) : (tensor<5x5x4x32xf32>) -> ()
+// CHECK-LABEL: func @test_concat_negative_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x1x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32>
+// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x4x32xf32>
+// CHECK-NEXT: }
+}
\ No newline at end of file