From 64c3d18c3f8c634c170652d2650fddb46560346e Mon Sep 17 00:00:00 2001 From: Philipp Braun <3589810@gmail.com> Date: Mon, 3 Oct 2022 19:01:04 +0200 Subject: [PATCH] [TOSA] Update type converter and unary ops (#1553) Signed-off-by: Philipp Braun Co-authored-by: Alexandre Eichenberger --- src/Conversion/ONNXToTOSA/CMakeLists.txt | 3 + .../ONNXToTOSA/ConvertONNXToTOSA.cpp | 80 ++++++---------- .../ONNXToTOSA/Math/Elementwise.cpp | 91 +++++++++++++++++++ .../ONNXToTOSA/ONNXToTOSACommon.hpp | 64 +++++++++++++ .../onnx_to_tosa/Math/Elementwise.mlir | 37 ++++++++ test/mlir/tosa/onnx_lowering.mlir | 11 --- 6 files changed, 223 insertions(+), 63 deletions(-) create mode 100644 src/Conversion/ONNXToTOSA/Math/Elementwise.cpp create mode 100644 src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp create mode 100644 test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir delete mode 100644 test/mlir/tosa/onnx_lowering.mlir diff --git a/src/Conversion/ONNXToTOSA/CMakeLists.txt b/src/Conversion/ONNXToTOSA/CMakeLists.txt index f5dace24c4..3f72a546d4 100644 --- a/src/Conversion/ONNXToTOSA/CMakeLists.txt +++ b/src/Conversion/ONNXToTOSA/CMakeLists.txt @@ -2,6 +2,9 @@ add_onnx_mlir_library(OMONNXToTOSA ConvertONNXToTOSA.cpp + + Math/Elementwise.cpp + LINK_LIBS PUBLIC OMONNXOps MLIRTosaDialect diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 06067c4b6a..67bdbc2c39 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -12,54 +12,18 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Dialect/ONNX/DialectBuilder.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" -#include "src/Pass/Passes.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" using namespace mlir; namespace onnx_mlir { -// This defines a template to construct ops whose legalizations are -// specialized. -template -class ConvertOnnxOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OnnxOpT::Adaptor; - LogicalResult matchAndRewrite(OnnxOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -template <> -LogicalResult ConvertOnnxOp::matchAndRewrite(ONNXReluOp op, - OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Value input = adaptor.X(); - auto inputTy = input.getType().dyn_cast(); - - if (!inputTy) - return op.emitError("Only Tensor types supported in TOSA"); - - if (!inputTy.getElementType().isa()) { - return op.emitError( - "Only floating-point datatype legalization currently supported"); - } - - // Rescale the clampIn for quantized types. TBD - // Maps to tosa.clamp which has both int and fp limits. - Value clampIn = input; - - rewriter.replaceOpWithNewOp(op, op.getType(), clampIn, - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(std::numeric_limits::max()), - rewriter.getF32FloatAttr(0.0f), - rewriter.getF32FloatAttr(std::numeric_limits::max())); - return success(); +void populateONNXToTOSAConversionPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + // Math + populateLoweringONNXElementwiseOpToTOSAPattern( + target, patterns, typeConverter, ctx); } // Performs lowering to TOSA dialect @@ -79,24 +43,36 @@ struct FrontendToTosaLoweringPass }; void FrontendToTosaLoweringPass::runOnOperation() { + ModuleOp module = getOperation(); + // Define final conversion target MLIRContext *context = &getContext(); RewritePatternSet patterns(context); ConversionTarget target(*context); + // We use the type converter to legalize types before any conversion patterns + // are executed. This ensures that we do not need to trigger separate + // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - + typeConverter.addConversion([](Type type) -> Optional { + if (isTOSASignedInt(type) || isTOSAFloat(type)) + return type; + return llvm::None; + }); + typeConverter.addConversion([&](TensorType type) -> Optional { + if (typeConverter.isLegal(type.getElementType())) + return type; + return llvm::None; + }); + + // Define legal dialects and operations target.addLegalDialect(); -#define INSERT_ONNXOP_PATTERN(OnnxOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_ONNXOP_PATTERN(ONNXReluOp); -#undef INSERT_ONNXOP_PATTERN + // Define patterns + populateONNXToTOSAConversionPattern(target, patterns, typeConverter, context); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); + } } std::unique_ptr createConvertONNXToTOSAPass() { diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp new file mode 100644 index 0000000000..b76c2a88c7 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Elementwise.cpp - Elementwise Op --------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX element-wise operators to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/TypeUtilities.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +template <> +struct TOSADialectOp { + using Op = tosa::NegateOp; +}; + +namespace { + +// Element-wise unary ops lowering to TOSA dialect. +//===----------------------------------------------------------------------===// +template +class ONNXElementwiseUnaryOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ElementwiseUnaryOp::Adaptor; + LogicalResult matchAndRewrite(ElementwiseUnaryOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp>( + op, op.getType(), adaptor.X()); + return success(); + } +}; + +class ONNXFloorOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXFloorOp::Adaptor; + LogicalResult matchAndRewrite(ONNXFloorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto scalarType = getElementTypeOrSelf(adaptor.X()); + if (!isTOSAFloat(scalarType)) + return rewriter.notifyMatchFailure( + op, "`tosa.floor` only supports float types"); + + rewriter.replaceOpWithNewOp(op, op.getType(), adaptor.X()); + return success(); + } +}; + +class ONNXReluOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value input = adaptor.X(); + + // Quantized types are not supported right now (in type conversion). + // Once they are, the input should be rescaled for quantized types. (TBD) + // Maps to `tosa.clamp` which has both int and fp limits. + rewriter.replaceOpWithNewOp(op, op.getType(), input, + rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(std::numeric_limits::max()), + rewriter.getF32FloatAttr(0.0f), + rewriter.getF32FloatAttr(std::numeric_limits::max())); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert, + ONNXFloorOpLoweringToTOSA, ONNXReluOpLoweringToTOSA>(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp new file mode 100644 index 0000000000..c9983f4811 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file contains common code shared by the functions performing the +// lowering to the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Transform/ONNX/ConstPropHelper.hpp" + +//===----------------------------------------------------------------------===// +// Functions to add lowering patterns for frontend operations. +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { + +//===----------------------------------------------------------------------===// +// Check for valid TOSA types. +//===----------------------------------------------------------------------===// + +inline bool isTOSASignedInt(Type type) { + IntegerType intType = type.dyn_cast(); + std::set intWidth{8, 16, 32, 48, 64}; + return intType && intType.isSigned() && + (intWidth.find(intType.getWidth()) != intWidth.end()); +} + +inline bool isTOSAFloat(Type type) { + return type.isa(); +} + +//===----------------------------------------------------------------------===// +// This is to get a TOSA operation of a given type for a specific operation. +//===----------------------------------------------------------------------===// +template +struct TOSADialectOp { + using Op = void; +}; + +template +using TOSAOp = typename TOSADialectOp::Op; + +// `Math` directory methods: +void populateLoweringONNXElementwiseOpToTOSAPattern( + ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *); +} // namespace onnx_mlir diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir new file mode 100644 index 0000000000..b63b0cac7d --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -0,0 +1,37 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_relu +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +func.func @test_relu_dynamic(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Relu"(%arg0) : (tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_relu_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor +// CHECK-NEXT: return [[VAR_0_]] : tensor +// CHECK-NEXT: } +} + +func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_neg +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> +} + +func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_floor +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> +} diff --git a/test/mlir/tosa/onnx_lowering.mlir b/test/mlir/tosa/onnx_lowering.mlir deleted file mode 100644 index a6a2bd2af0..0000000000 --- a/test/mlir/tosa/onnx_lowering.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: onnx-mlir-opt -O3 --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s - -func.func private @test_relu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Relu"(%arg0) : (tensor) -> tensor<*xf32> - "func.return"(%0) : (tensor<*xf32>) -> () - -// CHECK-LABEL: func private @test_relu( -// CHECK-SAME: [[INPUT:%.+]]: tensor) -> tensor { -// CHECK: [[OUTPUT:%.+]] = "tosa.clamp"([[INPUT]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor -// CHECK: return [[OUTPUT]] : tensor -}