From 0f021bc9fbe45a3c2ad7c69ab63da71f2b99b597 Mon Sep 17 00:00:00 2001
From: Soren Lassen
Date: Mon, 30 Oct 2023 09:31:10 -0700
Subject: [PATCH 1/2] const prop: round+saturate cast from fp to int types
 (#2593)

Signed-off-by: Soren Lassen
---
 src/Compiler/CompilerOptions.cpp              |   9 +
 src/Compiler/CompilerOptions.hpp              |   1 +
 src/Compiler/CompilerPasses.cpp               |   5 +-
 .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 259 ++++++++++++++----
 .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp |  23 +-
 src/Pass/Passes.hpp                           |   2 +-
 src/Transform/ONNX/ConstProp.cpp              |  71 ++---
 test/mlir/onnx/onnx_constprop.mlir            |  31 ++-
 .../onnx/onnx_constprop_expansion_bound.mlir  |  23 --
 test/mlir/onnx/onnx_constprop_flags.mlir      |  40 +++
 test/unit/SmallFP/TestSmallFP.cpp             |  27 ++
 11 files changed, 360 insertions(+), 131 deletions(-)
 delete mode 100644 test/mlir/onnx/onnx_constprop_expansion_bound.mlir
 create mode 100644 test/mlir/onnx/onnx_constprop_flags.mlir

diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp
index 4b4b9a5126..13ce6229da 100644
--- a/src/Compiler/CompilerOptions.cpp
+++ b/src/Compiler/CompilerOptions.cpp
@@ -32,6 +32,7 @@ std::string mtriple; // common for both
 std::string mcpu;                 // common for both
 std::string march;                // common for both
 InstrumentStages instrumentStage; // common for both
+bool onnxConstPropRoundFPToInt;   // common for both
 int onnxConstPropExpansionBound;  // common for both
 std::vector<std::string> onnxConstPropDisablePatterns; // common for both
 bool enableONNXHybridPass;        // common for both
@@ -156,6 +157,14 @@ static llvm::cl::opt<InstrumentStages, true> instrumentStageOpt(
         APPLY_TO_ACCELERATORS(ACCEL_INSTRUMENTSTAGE_CL_ENUM)),
     llvm::cl::init(Onnx), llvm::cl::cat(OnnxMlirCommonOptions));
 
+static llvm::cl::opt<bool, true> onnxConstPropRoundFPToIntOpt(
+    "onnx-const-prop-round-fp-to-int",
+    llvm::cl::desc("If true, constant propagates onnx.Cast from a floating "
+                   "point type to an integer type by rounding to nearest, "
+                   "ties to even. If false, truncates towards zero."),
+    llvm::cl::location(onnxConstPropRoundFPToInt), llvm::cl::init(false),
+    llvm::cl::cat(OnnxMlirCommonOptions));
+
 static llvm::cl::opt<int, true> onnxConstPropExpansionBoundOpt(
     "onnx-const-prop-expansion-bound",
     llvm::cl::desc("ONNX dialect constant propagation maximum expansion factor."

diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp
index f781e03424..3bab52d355 100644
--- a/src/Compiler/CompilerOptions.hpp
+++ b/src/Compiler/CompilerOptions.hpp
@@ -74,6 +74,7 @@ extern std::string mtriple; // common for both
 extern std::string mcpu;                 // common for both
 extern std::string march;                // common for both
 extern InstrumentStages instrumentStage; // common for both
+extern bool onnxConstPropRoundFPToInt;   // common for both
 extern int onnxConstPropExpansionBound;  // common for both
 extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
 extern bool enableONNXHybridPass;        // common for both

diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index 706df8eccf..dae13ef5e9 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -45,8 +45,9 @@ namespace onnx_mlir {
 void configurePasses() {
   // Set global vector machine support.
   VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
-  configureConstPropONNXToONNXPass(onnxConstPropExpansionBound,
-      onnxConstPropDisablePatterns, disableConstantProp);
+  configureConstPropONNXToONNXPass(onnxConstPropRoundFPToInt,
+      onnxConstPropExpansionBound, onnxConstPropDisablePatterns,
+      disableConstantProp);
   configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel,
       enableParallel, optReport == OptReport::Simd, !disableSimdOption);
 }

diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
index 73b8fda8c8..9251b49e38 100644
--- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
+++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
@@ -276,80 +276,213 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
   });
 }
 
+ElementsAttr ElementsAttrBuilder::castElementType(
+    ElementsAttr elms, Type newElementType) {
+  if (auto ftype = dyn_cast<FloatType>(newElementType)) {
+    // TODO: Consider saturating when ftype has no infinity:
+    //       saturate=APFloat::getInf(ftype.getFloatSemantics()).isNaN()
+    return castToFPElementType(elms, ftype);
+  }
+  if (auto itype = dyn_cast<IntegerType>(newElementType)) {
+    return castToIntElementType(elms, itype);
+  }
+  llvm_unreachable("unsupported newElementType");
+}
+
 namespace {
-using ElementsTransformer = std::function<void(llvm::MutableArrayRef<WideNum>)>;
-
-ElementsTransformer composeTransforms(
-    ElementsTransformer first, ElementsTransformer second) {
-  if (first == nullptr)
-    return second;
-  else
-    return [fst = std::move(first), snd = std::move(second)](
-               llvm::MutableArrayRef<WideNum> dst) {
-      fst(dst);
-      snd(dst);
-    };
-}
+// Rounds (ties to even) and saturates (out of range numbers become MIN or
+// MAX). Returns zero if from is NaN, like llvm::APFloat::convertToInteger().
+// From must be a floating point type (double, float, float_16, float_8e5m2).
+// To must be an integer type with size <= sizeof(int64_t), i.e., bitwidth <= 64.
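+//
+// For example (illustrative values, with TO = int8_t, so min/max = -128/127):
+//   convertIntFromDouble<int8_t, false>(2.5, -128, 127)  == 2    (tie to even)
+//   convertIntFromDouble<int8_t, false>(3.5, -128, 127)  == 4    (tie to even)
+//   convertIntFromDouble<int8_t, true>(-1.7, -128, 127)  == -1   (truncate)
+//   convertIntFromDouble<int8_t, false>(1e10, -128, 127) == 127  (saturate)
+//   convertIntFromDouble<int8_t, false>(NAN, -128, 127)  == 0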
+//
+// TODO: consider making it configurable whether to convert NaN to the
+//       number farthest from zero (like X86 SSE), or just the highest bit
+//       set (like CUDA), or zero
+//
+// TODO: optimize w/X86 SSE instructions https://stackoverflow.com/a/47347224
+//
+template <typename TO, bool TRUNCATE>
+TO convertIntFromDouble(double from, TO min, TO max) {
+  if (std::isnan(from))
+    return 0;
+  if (from < static_cast<double>(min))
+    return min;
+  // static_cast<double>(max) can round to a larger number,
+  // so return max if from is greater or equal, not just if greater
+  if (from >= max)
+    return max;
+
+  if (TRUNCATE)
+    return static_cast<TO>(from);
+
+  // llrint recommendation: https://stackoverflow.com/a/47347224
+  // rounds to nearest, ties to even, in the default rounding mode
+  using llrintType = decltype(llrint(from));
+  if constexpr (std::is_same_v<TO, uint64_t>) {
+    static_assert(
+        sizeof(llrintType) >= sizeof(TO), "insufficient llrint range");
+    // llrintType is int64_t which doesn't cover the numeric range of uint64_t
+    // so we work around this by breaking the range into 2 as follows:
+    uint64_t mid = uint64_t(1) << 63; // middle of uint64_t numeric range
+    if (from < mid) {
+      // from is inside llrint's numerical range [-2^63, 2^63)
+      return llrint(from);
+    } else {
+      // subtract and add to translate into llrint's numeric range and back
+      return mid + llrint(from - mid);
+    }
+  } else {
+    // llrintType covers the numeric range of TO, namely llrintType is int64_t
+    // and TO is int64_t or a narrower signed or unsigned type
+    static_assert(sizeof(llrintType) > sizeof(TO) ||
+                      (sizeof(llrintType) == sizeof(TO) &&
+                          std::numeric_limits<TO>::is_signed),
+        "insufficient llrint range");
+    return llrint(from);
+  }
+}
 
-template <typename SrcT, typename DstT>
-struct Caster {
-  static inline constexpr DstT eval(SrcT src) { return static_cast<DstT>(src); }
-};
+template <typename TO, bool TRUNCATE>
+auto convertIntFromFP(TO min, TO max) {
+  return [min, max](WideNum n) -> WideNum {
+    double from = n.narrow<BType::DOUBLE>();
+    TO to = convertIntFromDouble<TO, TRUNCATE>(from, min, max);
+    return WideNum::widen<toBType<TO>>(to);
+  };
+}
 
-template <typename SrcT, typename DstT>
-using WideCaster = WideNumWrappedFunction<Caster<SrcT, DstT>>;
-
-auto wideCaster(BType src, BType dst) -> WideNum (*)(WideNum) {
-  constexpr BType DBL = BType::DOUBLE, I64 = BType::INT64, U64 = BType::UINT64;
-  // clang-format off
-  if (src == DBL && dst == I64) return WideCaster<double, int64_t>::eval;
-  if (src == DBL && dst == U64) return WideCaster<double, uint64_t>::eval;
-  if (src == I64 && dst == DBL) return WideCaster<int64_t, double>::eval;
-  if (src == I64 && dst == U64) return WideCaster<int64_t, uint64_t>::eval;
-  if (src == U64 && dst == DBL) return WideCaster<uint64_t, double>::eval;
-  if (src == U64 && dst == I64) return WideCaster<uint64_t, int64_t>::eval;
-  // clang-format on
-  llvm_unreachable("wideCaster must be called with 2 different wide types");
-}
+template <typename T>
+WideNum isWideNonZero(WideNum n) {
+  return WideNum::widen<BType::BOOL>(n.narrow<toBType<T>>() != 0);
+}
+
+template <typename TO, typename FROM>
+WideNum wideCast(WideNum n) {
+  return WideNum::widen<toBType<TO>>(
+      static_cast<TO>(n.narrow<toBType<FROM>>()));
+};
+
+template <typename T>
+double wideToDouble(WideNum n) {
+  return static_cast<double>(n.narrow<toBType<T>>());
+};
+
 } // namespace
 
-ElementsAttr ElementsAttrBuilder::castElementType(
-    ElementsAttr elms, Type newElementType) {
+ElementsAttr ElementsAttrBuilder::castToIntElementType(
+    ElementsAttr elms, IntegerType newElementType, bool round) {
   Type oldElementType = elms.getElementType();
   if (newElementType == oldElementType)
     return elms;
 
-  ElementsProperties props = getElementsProperties(elms);
+  Transformer transformer;
+  if (newElementType.isInteger(1)) {
+    // Bool: +/-zero cast to 0, everything else including NaN cast to 1.
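+    // For example, f32 [-0.0, 0.25, NaN] casts to i1 [0, 1, 1]
+    // (illustrative values; see also test_cast_i32_i1_i32 below).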
+    transformer = wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
+      using cpptype = decltype(wideZero);
+      return functionTransformer(isWideNonZero<cpptype>);
+    });
+  } else if (isa<FloatType>(oldElementType)) {
+    constexpr bool ROUND = false, TRUNCATE = true;
+    unsigned width = newElementType.getWidth();
+    if (newElementType.isUnsigned()) {
+      uint64_t min = 0;
+      uint64_t max = std::numeric_limits<uint64_t>::max() >> (64 - width);
+      transformer = round ? functionTransformer(
+                                convertIntFromFP<uint64_t, ROUND>(min, max))
+                          : functionTransformer(
+                                convertIntFromFP<uint64_t, TRUNCATE>(min, max));
+    } else {
+      int64_t min = std::numeric_limits<int64_t>::min() >> (64 - width);
+      int64_t max = std::numeric_limits<int64_t>::max() >> (64 - width);
+      transformer = round ? functionTransformer(
+                                convertIntFromFP<int64_t, ROUND>(min, max))
+                          : functionTransformer(
+                                convertIntFromFP<int64_t, TRUNCATE>(min, max));
+    }
+  } else if (isa<IntegerType>(oldElementType)) {
+    // We assume that casts to other integer types don't intend to truncate
+    // the numeric range, so we delay any truncation until the data is read
+    // and allow the untruncated numbers as inputs to any further
+    // transformations.
+    //
+    // TODO: Add configuration options to support other behaviors.
+    //       See https://github.com/onnx/onnx-mlir/issues/2209
+    if (newElementType.isUnsigned() != oldElementType.isUnsignedInteger()) {
+      // DisposableElementsAttr requires transformation between integers with
+      // different signs.
+      // TODO: Consider relaxing the requirement and omit this transformation.
+      transformer = newElementType.isUnsigned()
+                        ? functionTransformer(wideCast<uint64_t, int64_t>)
+                        : functionTransformer(wideCast<int64_t, uint64_t>);
+    } else {
+      ElementsProperties props = getElementsProperties(elms);
+      ShapedType newType = elms.getShapedType().clone(newElementType);
+      return create(newType, props.bufferBType, props.strides, props.buffer,
+          props.transformer);
+    }
+  } else {
+    llvm_unreachable("unsupported element type");
+  }
+  return doTransform(elms, newElementType, transformer);
+}
 
-  ShapedType newType = elms.getShapedType().clone(newElementType);
-  BType newBType = btypeOfMlirType(newElementType);
-  BType oldBType = btypeOfMlirType(oldElementType);
-  BType newWideType = wideBTypeOfBType(newBType);
-  BType oldWideType = wideBTypeOfBType(oldBType);
-
-  auto transformer =
-      oldWideType == newWideType
-          ? props.transformer
-          : composeTransforms(props.transformer,
-                functionTransformer(wideCaster(oldWideType, newWideType)));
-  return create(newType, props.bufferBType, props.strides, props.buffer,
-      std::move(transformer));
+ElementsAttr ElementsAttrBuilder::castToFPElementType(
+    ElementsAttr elms, FloatType newElementType, bool saturate) {
+  Type oldElementType = elms.getElementType();
+  if (newElementType == oldElementType)
+    return elms;
+
+  return wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
+    using cpptype = decltype(wideZero);
+    Transformer transformer;
+    if (saturate) {
+      // Smallest is -max for all ONNX fp types.
+      const double max = APFloat::getLargest(newElementType.getFloatSemantics())
+                             .convertToDouble();
+      // Note that we saturate by clipping, which isn't 100% faithful to the
+      // onnx spec here: https://onnx.ai/onnx/technical/float8.html
+      // and here: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
+      // which, in the case of E4M3FNUZ and E5M2FNUZ, requires infinite values
+      // to saturate to NaN, whereas we saturate them to smallest/largest with
+      // clipping. Our clipping implementation matches the reference
+      // implementation in onnx/reference/ops/op_cast.py.
+      // See https://github.com/onnx/onnx-mlir/issues/2369
+      //
+      // TODO: Change implementation to match the spec, or change the spec.
+      transformer = functionTransformer([max](WideNum n) {
+        double d = wideToDouble<cpptype>(n);
+        return WideNum::widen<BType::DOUBLE>(
+            // Order of operations is important to ensure NaN stays NaN:
+            d <= -max ? -max : (d >= max ? max : d));
+      });
+    } else if constexpr (std::is_integral_v<cpptype>) {
+      transformer = functionTransformer([](WideNum n) {
+        return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
+      });
+    } else {
+      ElementsProperties props = getElementsProperties(elms);
+      ShapedType newType = elms.getShapedType().clone(newElementType);
+      return create(newType, props.bufferBType, props.strides, props.buffer,
+          props.transformer);
+    }
+    return doTransform(elms, newElementType, transformer);
+  });
 }
 
 ElementsAttr ElementsAttrBuilder::clip(
     ElementsAttr elms, WideNum min, WideNum max) {
   return wideZeroDispatchNonBool(elms.getElementType(), [&](auto wideZero) {
     using cpptype = decltype(wideZero);
-    return doTransform(
-        elms, elms.getElementType(), functionTransformer([min, max](WideNum n) {
-          constexpr BType TAG = toBType<cpptype>;
-          cpptype x = n.narrow<TAG>();
-          if (x < min.narrow<TAG>())
-            return min;
-          if (x > max.narrow<TAG>())
-            return max;
-          return n;
-        }));
+    return transform(elms, elms.getElementType(), [min, max](WideNum n) {
+      constexpr BType TAG = toBType<cpptype>;
+      cpptype x = n.narrow<TAG>();
+      if (x < min.narrow<TAG>())
+        return min;
+      if (x > max.narrow<TAG>())
+        return max;
+      return n;
+    });
   });
 }
 
@@ -983,6 +1116,22 @@ ArrayBuffer<WideNum> ElementsAttrBuilder::getWideNumsAndExpandedStrides(
   };
 }
 
+namespace {
+using ElementsTransformer = std::function<void(llvm::MutableArrayRef<WideNum>)>;
+
+ElementsTransformer composeTransforms(
+    ElementsTransformer first, ElementsTransformer second) {
+  if (first == nullptr)
+    return second;
+  else
+    return [fst = std::move(first), snd = std::move(second)](
+               llvm::MutableArrayRef<WideNum> dst) {
+      fst(dst);
+      snd(dst);
+    };
+}
+} // namespace
+
 ElementsAttr ElementsAttrBuilder::doTransform(
     ElementsAttr elms, Type transformedElementType, Transformer transformer) {
   ShapedType transformedType =

diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
index e4210051f8..0414aaabe6 100644
--- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
+++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
@@ -116,12 +116,33 @@ class ElementsAttrBuilder {
   mlir::ElementsAttr where(mlir::ElementsAttr cond, mlir::ElementsAttr lhs,
       mlir::ElementsAttr rhs, mlir::ShapedType combinedType);
 
-  // Returns an ElementsAttr with the elements cast to the given newElementType.
+  // Returns an ElementsAttr with the elements cast to the given newElementType,
+  // with default choices for rounding (true) and saturation (false).
   //
   // Reuses elms' underlying data without a data copy.
   mlir::ElementsAttr castElementType(
      mlir::ElementsAttr elms, mlir::Type newElementType);
 
+  // Returns an ElementsAttr with the elements cast to the given intElementType.
+  //
+  // If round==true and elms has a floating point element type, then the
+  // numbers are rounded to the nearest integer, ties to even; otherwise they
+  // are truncated towards zero.
+  //
+  // Reuses elms' underlying data without a data copy.
+  mlir::ElementsAttr castToIntElementType(mlir::ElementsAttr elms,
+      mlir::IntegerType newElementType, bool round = true);
+
+  // Returns an ElementsAttr with the elements cast to the given fpElementType.
+  //
+  // If saturate==true, then out of range numbers are clipped to the finite
+  // range; otherwise, when newElementType has +/-infinity, they are cast to
+  // +/-infinity.
+  //
+  // Reuses elms' underlying data without a data copy.
+  mlir::ElementsAttr castToFPElementType(mlir::ElementsAttr elms,
+      mlir::FloatType newElementType, bool saturate = false);
+
   // Returns an ElementsAttr with the values clipped to the range [min, max].
   //
   // Reuses elms' underlying data without a data copy.

diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index f6c625626b..731fa6b503 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -45,7 +45,7 @@ std::unique_ptr<mlir::Pass> createConvOptONNXToONNXPass(
 std::unique_ptr<mlir::Pass> createShapeInferencePass();
 
 // To configure ConstPropONNXToONNXPass at program start.
-void configureConstPropONNXToONNXPass(int expansionBound,
+void configureConstPropONNXToONNXPass(bool roundFPToInt, int expansionBound,
     llvm::ArrayRef<std::string> disabledPatterns, bool constantPropIsDisabled);
 
 std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();

diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp
index 2c04db251d..116a41d653 100644
--- a/src/Transform/ONNX/ConstProp.cpp
+++ b/src/Transform/ONNX/ConstProp.cpp
@@ -63,11 +63,13 @@ namespace {
 // Populated by configureConstPropONNXToONNXPass().
 struct ConstPropONNXToONNXPassConfiguration {
+  static bool roundFPToInt;
   static int expansionBound;
   static StringSet<> disabledPatterns;
   static bool constantPropIsDisabled;
 };
 
+bool ConstPropONNXToONNXPassConfiguration::roundFPToInt = false;
 int ConstPropONNXToONNXPassConfiguration::expansionBound = -1; // -1 == no bound
 StringSet<> ConstPropONNXToONNXPassConfiguration::disabledPatterns = {};
 bool ConstPropONNXToONNXPassConfiguration::constantPropIsDisabled = false;
@@ -608,9 +610,9 @@ ElementsAttr getMatMulIntegerMatrixElements(
     ElementsAttrBuilder &elementsBuilder, Value matrixValue,
     Value zeroPointValue,
     function_ref<ElementsAttr(llvm::ArrayRef<int64_t>, ElementsAttr)> reshapeZero) {
-  Type I32 = IntegerType::get(matrixValue.getContext(), 32);
+  auto I32 = IntegerType::get(matrixValue.getContext(), 32);
   ElementsAttr matrix8 = getConstValueElements(matrixValue);
-  ElementsAttr matrix32 = elementsBuilder.castElementType(matrix8, I32);
+  ElementsAttr matrix32 = elementsBuilder.castToIntElementType(matrix8, I32);
   if (isNoneValue(zeroPointValue)) {
     return matrix32;
   } else {
@@ -618,7 +620,7 @@ ElementsAttr getMatMulIntegerMatrixElements(
     ElementsAttr zeroPoint8 = getConstValueElements(zeroPointValue);
     ElementsAttr reshapedZeroPoint8 =
         reshapeZero(matrix8.getShapedType().getShape(), zeroPoint8);
     ElementsAttr reshapedZeroPoint32 =
-        elementsBuilder.castElementType(reshapedZeroPoint8, I32);
+        elementsBuilder.castToIntElementType(reshapedZeroPoint8, I32);
     return elementsBuilder.combine(matrix32, reshapedZeroPoint32,
         matrix32.getShapedType(), subCombiner(I32));
     // elementwiseBinaryOpCombiner(I32));
   }
@@ -646,7 +648,7 @@ Value ConstPropGemm(PatternRewriter &rewriter, Value replacingValue,
   constexpr std::array<uint64_t, 2> TRANSPOSE = {1, 0};
   ArrayRef<uint64_t> permLhs = gemmOp.getTransA() == 0 ? IDENTITY : TRANSPOSE;
   ArrayRef<uint64_t> permRhs = gemmOp.getTransB() == 0 ? IDENTITY : TRANSPOSE;
-  Type F64 = rewriter.getF64Type();
+  FloatType F64 = rewriter.getF64Type();
   ShapedType resType = cast<ShapedType>(replacingValue.getType());
   OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
   ElementsAttr lhs = getConstValueElements(lhsMatrixValue);
@@ -655,7 +657,7 @@ Value ConstPropGemm(PatternRewriter &rewriter, Value replacingValue,
       elementsBuilder.matMul(elementsBuilder.transpose(lhs, permLhs),
           elementsBuilder.transpose(rhs, permRhs));
   if (alpha != 1.0) {
-    res = elementsBuilder.castElementType(res, F64);
+    res = elementsBuilder.castToFPElementType(res, F64);
     res = elementsBuilder.transform(res, F64, [alpha](WideNum n) {
       return WideNum::widen<BType::DOUBLE>(alpha * n.narrow<BType::DOUBLE>());
     });
@@ -664,7 +666,7 @@ Value ConstPropGemm(PatternRewriter &rewriter, Value replacingValue,
   if (hasBias) {
     ElementsAttr bias = getConstValueElements(biasMatrixValue);
     if (beta != 1.0) {
-      bias = elementsBuilder.castElementType(bias, F64);
+      bias = elementsBuilder.castToFPElementType(bias, F64);
       bias = elementsBuilder.transform(bias, F64, [beta](WideNum n) {
         return WideNum::widen<BType::DOUBLE>(beta * n.narrow<BType::DOUBLE>());
       });
@@ -672,8 +674,8 @@ Value ConstPropGemm(PatternRewriter &rewriter, Value replacingValue,
     // If one of res or bias has been cast to F64 then also cast the other.
     if (res.getElementType() != bias.getElementType()) {
       // One cast is unnecessary but ok: cast to the same type is free.
-      res = elementsBuilder.castElementType(res, F64);
-      bias = elementsBuilder.castElementType(bias, F64);
+      res = elementsBuilder.castToFPElementType(res, F64);
+      bias = elementsBuilder.castToFPElementType(bias, F64);
     }
     // elemType will be F64 if alpha != 1.0 or beta != 1.0.
     Type elemType = res.getElementType();
@@ -757,39 +759,22 @@ Value ConstPropCast(PatternRewriter &rewriter, Value replacingValue,
   ElementsAttr constElements = getConstValueElements(constValue);
   OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
-  ElementsAttr castElements =
-      elementsBuilder.castElementType(constElements, toType);
-
-  // 'saturate' is ignored unless toType is a 8 bits float type.
-  if (saturate.getSInt() != 0 && isa<FloatType>(toType) &&
-      toType.getIntOrFloatBitWidth() == 8) {
-    float max =
-        dispatchByBType(btypeOfMlirType(toType), [&](auto btype) -> float {
-          using cpptype = CppType<btype>;
-          if constexpr (isSmallFPType<cpptype>) {
-            return cpptype::max;
-          } else {
-            llvm_unreachable("unsupported 8 bits floating point type");
-          }
-        });
-    // Clipping after cast relies on that cast is lazy and represents
-    // elements as doubles until they are materialized, so it's not too
-    // late to clip them here.
-    // TODO: Clean up the contracts to make it clearer what's going on.
-    //
-    // Note that we saturate by clipping which isn't 100% faithful to the
-    // onnx spec here: https://onnx.ai/onnx/technical/float8.html
-    // and here: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
-    // which, in the case of E4M3FNUZ and E5M2FNUZ, requires infinite values
-    // to saturate to NaN, whereas we saturate them to lowest/highest with
-    // clipping. Our clipping implementation matchint the reference
-    // implementation in onnx/reference/ops/op_cast.py.
-    // TODO: Change our implementation to match the spec, or change the spec.
-    WideNum lowest = WideNum::widen<BType::DOUBLE>(-max);
-    WideNum highest = WideNum::widen<BType::DOUBLE>(max);
-    castElements = elementsBuilder.clip(castElements, lowest, highest);
+  ElementsAttr castElements;
+  if (auto ftype = dyn_cast<FloatType>(toType)) {
+    bool doSaturate = saturate.getSInt() != 0 && ftype.getWidth() == 8;
+    castElements =
+        elementsBuilder.castToFPElementType(constElements, ftype, doSaturate);
+  } else if (auto itype = dyn_cast<IntegerType>(toType)) {
+    // The onnx.Cast spec doesn't say whether a cast from a floating point
+    // type to an integer type should truncate towards zero or round, but past
+    // discussions (onnx issues #2285, #3776, #5004) point to truncation, like
+    // numpy. But round to nearest, ties to even, is preferable for numerics.
+    bool round = ConstPropONNXToONNXPassConfiguration::roundFPToInt;
+    castElements =
+        elementsBuilder.castToIntElementType(constElements, itype, round);
+  } else {
+    llvm_unreachable("cast to unsupported type");
   }
-
   return createReplacingConstantOp(rewriter, replacingValue, castElements);
 }
@@ -1102,8 +1087,10 @@ void onnx_mlir::getConstPropONNXToONNXPatterns(RewritePatternSet &patterns) {
   patterns.insert(patterns.getContext());
 }
 
-void onnx_mlir::configureConstPropONNXToONNXPass(int expansionBound,
-    ArrayRef<std::string> disabledPatterns, bool constantPropIsDisabled) {
+void onnx_mlir::configureConstPropONNXToONNXPass(bool roundFPToInt,
+    int expansionBound, ArrayRef<std::string> disabledPatterns,
+    bool constantPropIsDisabled) {
+  ConstPropONNXToONNXPassConfiguration::roundFPToInt = roundFPToInt;
   ConstPropONNXToONNXPassConfiguration::expansionBound = expansionBound;
   ConstPropONNXToONNXPassConfiguration::disabledPatterns.insert(
       disabledPatterns.begin(), disabledPatterns.end());

diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir
index ad55049e59..724fe54824 100644
--- a/test/mlir/onnx/onnx_constprop.mlir
+++ b/test/mlir/onnx/onnx_constprop.mlir
@@ -1253,6 +1253,21 @@ func.func @test_mul_folding(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf32> {
 
 // -----
 
+func.func @test_cast_i32_i1_i32() -> tensor<4xi32> {
+  %0 = onnx.Constant dense<[-1, 0, 1, 2]> : tensor<4xi32>
+  %1 = "onnx.Cast"(%0) {to = i1} : (tensor<4xi32>) -> tensor<4xi1>
+  %2 = "onnx.Cast"(%1) {to = i32} : (tensor<4xi1>) -> tensor<4xi32>
+  "onnx.Return"(%2) : (tensor<4xi32>) -> ()
+
+  // CHECK-LABEL:  func @test_cast_i32_i1_i32
+  // CHECK-SAME:   () -> tensor<4xi32> {
+  // CHECK:        [[VAR_0_:%.+]] = onnx.Constant dense<[1, 0, 1, 1]> : tensor<4xi32>
+  // CHECK:        onnx.Return [[VAR_0_]] : tensor<4xi32>
+  // CHECK:        }
+}
+
+// -----
+
 func.func @test_cast_i32_i64() -> tensor<3x2xi64> {
   %0 = onnx.Constant dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>
   %1 = "onnx.Cast"(%0) {to = i64} : (tensor<3x2xi32>) -> tensor<3x2xi64>
@@ -1295,15 +1310,17 @@ func.func @test_cast_i32_f32() -> tensor<3x2xf32> {
 
 // -----
 
-func.func @test_cast_f32_i32() -> tensor<3x2xi32> {
-  %0 = onnx.Constant dense<[[2.3, 3.6], [4.5, 5.5], [6.0, 7.0]]> : tensor<3x2xf32>
-  %1 = "onnx.Cast"(%0) {to = i32} : (tensor<3x2xf32>) -> tensor<3x2xi32>
-  "onnx.Return"(%1) : (tensor<3x2xi32>) -> ()
+func.func @test_cast_f32_i32() -> tensor<8xi32> {
+  // COM: 0x7F800000/0xFF800000 are +/-INF
+  // COM: 0x7F800001/0xFFFFFFFF are smallest positive NaN/largest negative NaN
+  %0 = onnx.Constant dense<[2.3, 3.6, -1.0e10, 1.0e10, 0x7F800000, 0xFF800000, 0x7F800001, 0xFFFFFFFF]> : tensor<8xf32>
+  %1 = "onnx.Cast"(%0) {to = i32} : (tensor<8xf32>) -> tensor<8xi32>
+  "onnx.Return"(%1) : (tensor<8xi32>) -> ()
 
   // CHECK-LABEL:  func @test_cast_f32_i32
-  // CHECK-SAME:   () -> tensor<3x2xi32> {
- // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}[2, 3], [4, 5], [6, 7]{{.}}> : tensor<3x2xi32> - // CHECK: onnx.Return [[VAR_0_]] : tensor<3x2xi32> + // CHECK-SAME: () -> tensor<8xi32> { + // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, -2147483648, 2147483647, 2147483647, -2147483648, 0, 0]> : tensor<8xi32> + // CHECK: onnx.Return [[VAR_0_]] : tensor<8xi32> // CHECK: } } diff --git a/test/mlir/onnx/onnx_constprop_expansion_bound.mlir b/test/mlir/onnx/onnx_constprop_expansion_bound.mlir deleted file mode 100644 index 8a24809b14..0000000000 --- a/test/mlir/onnx/onnx_constprop_expansion_bound.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s - -//===----------------------------------------------------------------------===// -// Constant propagate ONNXAddOp only if expansion bound satisfied -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @test_add_propagates() -> tensor<2x5xf32> -func.func @test_add_propagates() -> tensor<2x5xf32> { - %0 = "onnx.Constant"() {value = dense<1.0> : tensor<2x1xf32>} : () -> tensor<2x1xf32> - %1 = "onnx.Constant"() {value = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32> - %2 = "onnx.Add"(%0, %1) : (tensor<2x1xf32> , tensor<1x5xf32>) -> tensor<2x5xf32> - onnx.Return %2 : tensor<2x5xf32> - // CHECK: onnx.Constant {{.*}} : tensor<2x5xf32> -} - -// CHECK-LABEL: @test_add_doesnt_propagate() -> tensor<5x5xf32> -func.func @test_add_doesnt_propagate() -> tensor<5x5xf32> { - %0 = "onnx.Constant"() {value = dense<1.0> : tensor<5x1xf32>} : () -> tensor<5x1xf32> - %1 = "onnx.Constant"() {value = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32> - %2 = "onnx.Add"(%0, %1) : (tensor<5x1xf32> , tensor<1x5xf32>) -> tensor<5x5xf32> - onnx.Return %2 : tensor<5x5xf32> - // CHECK: "onnx.Add"(%0, %1) : (tensor<5x1xf32>, tensor<1x5xf32>) -> tensor<5x5xf32> -} diff --git a/test/mlir/onnx/onnx_constprop_flags.mlir b/test/mlir/onnx/onnx_constprop_flags.mlir new file mode 100644 index 0000000000..0fb4ddcf1d --- /dev/null +++ b/test/mlir/onnx/onnx_constprop_flags.mlir @@ -0,0 +1,40 @@ +// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck --check-prefix=EXPANSIONBOUND2 %s +// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-round-fp-to-int=true %s -split-input-file | FileCheck --check-prefix=ROUND %s +// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-round-fp-to-int=false %s -split-input-file | FileCheck --check-prefix=TRUNCATE %s + +//===----------------------------------------------------------------------===// +// Constant propagate ONNXAddOp only if expansion bound satisfied +//===----------------------------------------------------------------------===// + +func.func @test_add_propagates() -> tensor<2x5xf32> { + %0 = "onnx.Constant"() {value = dense<1.0> : tensor<2x1xf32>} : () -> tensor<2x1xf32> + %1 = "onnx.Constant"() {value = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32> + %2 = "onnx.Add"(%0, %1) : (tensor<2x1xf32> , tensor<1x5xf32>) -> tensor<2x5xf32> + onnx.Return %2 : tensor<2x5xf32> +} +// EXPANSIONBOUND2-LABEL: @test_add_propagates() -> tensor<2x5xf32> +// EXPANSIONBOUND2: onnx.Constant {{.*}} : tensor<2x5xf32> + +// ----- + +func.func @test_add_doesnt_propagate() -> tensor<5x5xf32> { + %0 = "onnx.Constant"() {value = dense<1.0> : tensor<5x1xf32>} : () -> tensor<5x1xf32> + %1 = "onnx.Constant"() {value = dense<2.0> : tensor<1x5xf32>} 
: () -> tensor<1x5xf32>
+  %2 = "onnx.Add"(%0, %1) : (tensor<5x1xf32> , tensor<1x5xf32>) -> tensor<5x5xf32>
+  onnx.Return %2 : tensor<5x5xf32>
+}
+// EXPANSIONBOUND2-LABEL: @test_add_doesnt_propagate() -> tensor<5x5xf32>
+// EXPANSIONBOUND2: "onnx.Add"(%0, %1) : (tensor<5x1xf32>, tensor<1x5xf32>) -> tensor<5x5xf32>
+
+// -----
+
+func.func @test_cast_f16_i16() -> tensor<6xi16> {
+  %0 = onnx.Constant dense<[-1.5, -0.5, 0.4, 0.5, 1.5, 1.6]> : tensor<6xf16>
+  %1 = "onnx.Cast"(%0) {to = i16} : (tensor<6xf16>) -> tensor<6xi16>
+  onnx.Return %1 : tensor<6xi16>
+}
+// ROUND-LABEL: @test_cast_f16_i16() -> tensor<6xi16>
+// ROUND: onnx.Constant dense<[-2, 0, 0, 0, 2, 2]> : tensor<6xi16>
+//
+// TRUNCATE-LABEL: @test_cast_f16_i16() -> tensor<6xi16>
+// TRUNCATE: onnx.Constant dense<[-1, 0, 0, 0, 1, 1]> : tensor<6xi16>

diff --git a/test/unit/SmallFP/TestSmallFP.cpp b/test/unit/SmallFP/TestSmallFP.cpp
index c1e2e9171f..7fe4dfbda7 100644
--- a/test/unit/SmallFP/TestSmallFP.cpp
+++ b/test/unit/SmallFP/TestSmallFP.cpp
@@ -142,6 +142,26 @@ class Test {
 
     return 0;
   }
+
+  template <typename FP>
+  int test_fp_infinity(const char *fp_name) {
+    std::cout << "test_fp_infinity " << fp_name << ":" << std::endl;
+
+    assert(!llvm::APFloat::getInf(FP::semantics()).isNaN());
+    assert(!FP::fromFloat(INFINITY).isNaN());
+
+    return 0;
+  }
+
+  template <typename FP>
+  int test_fp_no_infinity(const char *fp_name) {
+    std::cout << "test_fp_no_infinity " << fp_name << ":" << std::endl;
+
+    assert(llvm::APFloat::getInf(FP::semantics()).isNaN());
+    assert(FP::fromFloat(INFINITY).isNaN());
+
+    return 0;
+  }
 };
 
 template <typename FP>
@@ -239,6 +259,13 @@ int main(int argc, char *argv[]) {
   failures += test.test_fp_equals<float_16>("float_16", fp16min, fp16max);
   failures += test.test_fp_equals<bfloat_16>("bfloat_16", fp16min, fp16max);
 
+  failures += test.test_fp_infinity<float_16>("float_16");
+  failures += test.test_fp_infinity<bfloat_16>("bfloat_16");
+  failures += test.test_fp_no_infinity<float_8e4m3fn>("float_8e4m3fn");
+  failures += test.test_fp_no_infinity<float_8e4m3fnuz>("float_8e4m3fnuz");
+  failures += test.test_fp_infinity<float_8e5m2>("float_8e5m2");
+  failures += test.test_fp_no_infinity<float_8e5m2fnuz>("float_8e5m2fnuz");
+
   if (failures != 0) {
     std::cerr << failures << " test failures\n";
     return 1;

From faa42c812a168d6b912d06a7951c1e27067e8c2a Mon Sep 17 00:00:00 2001
From: Tong Chen
Date: Mon, 30 Oct 2023 16:37:27 -0400
Subject: [PATCH 2/2] Fuse op when the shape is not static (#2577)

* transformation

Signed-off-by: chentong319

* condition

Signed-off-by: chentong319

* test

Signed-off-by: chentong319

* format

Signed-off-by: chentong319

* more test

Signed-off-by: chentong319

* fix

Signed-off-by: chentong319

---------

Signed-off-by: chentong319
---
 .../ONNXToKrnl/Math/Elementwise.cpp           | 139 ++++++++++--------
 .../onnx_to_krnl/onnx_lowering_fuse.mlir      |  91 ++++++++++++
 utils/pre-onnx-mlir.py                        |   1 +
 3 files changed, 169 insertions(+), 62 deletions(-)

diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
index 51502b228a..e82a1e4c98 100644
--- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1580,16 +1580,18 @@ typedef mlir::Value (*EmitScalarFunc)(mlir::ConversionPatternRewriter &rewriter,
 class OpFusionHelper {
 public:
   // Constructor
-  OpFusionHelper(
-      mlir::ConversionPatternRewriter &rewriter, mlir::Operation *rootOp)
-      : rootOp(rootOp), rewriter(rewriter), fusibleOps(), fuseEmitFuctions() {}
+  OpFusionHelper(mlir::ConversionPatternRewriter &rewriter,
+      mlir::Operation *rootOp, DimAnalysis *dimAnalysis)
+      : rootOp(rootOp), rewriter(rewriter), dimAnalysis(dimAnalysis),
+        fusibleOps(), fuseEmitFuctions() {}
 
   // Fusion should not break any control dependence
   static bool isControlFlowValidForFusion(Operation *useOp, Operation *defOp);
 
   // Check whether the inputs of the useOp are valid for useOp to be fused
   // with the defOp. The defOp defines one of useOp's inputs.
-  static bool areInputsValidForFusion(Operation *useOp, Operation *defOp);
+  static bool areInputsValidForFusion(
+      Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis);
 
   // Check whether the op is fusible along the use-def chain from the defOp.
   // If true, record the op and its scalar op.
@@ -1607,13 +1609,18 @@ class OpFusionHelper {
   // For example, comparison ops, and cast Op.
   MemRefType getOutputType(MemRefType outputType);
 
-  Value emitFuseOps(Value producerResult, ValueRange loopInd = {});
+  // Generate the code for the ops to be fused.
+  // producerResult is the scalar value from the producer.
+  // alloc is used to get the tensor for the producer, which is required
+  // by the shape helper.
+  Value emitFuseOps(Value producerResult, Value alloc, ValueRange loopInd = {});
 
   void replaceOrEraseONNXOps(Value alloc);
 
 private:
   mlir::Operation *rootOp;
   mlir::ConversionPatternRewriter &rewriter;
+  DimAnalysis *dimAnalysis;
   llvm::SmallVector<Operation *> fusibleOps;
   llvm::SmallVector<EmitScalarFunc> fuseEmitFuctions;
 }; // End of OpFusionHelper Declaration
 
@@ -1623,10 +1630,11 @@ class OpFusionHelper {
 template <typename Op>
 bool enqueueFusibleOpImpl(Operation *useOp, Operation *defOp,
     SmallVector<Operation *> &fusibleOps,
-    SmallVector<EmitScalarFunc> &fuseEmitFunctions) {
+    SmallVector<EmitScalarFunc> &fuseEmitFunctions,
+    DimAnalysis *dimAnalysis) {
   if (isa<Op>(useOp)) {
    if (OpFusionHelper::isControlFlowValidForFusion(useOp, defOp) &&
-        OpFusionHelper::areInputsValidForFusion(useOp, defOp)) {
+        OpFusionHelper::areInputsValidForFusion(useOp, defOp, dimAnalysis)) {
      fusibleOps.emplace_back(useOp);
      fuseEmitFunctions.emplace_back(emitScalarOpFor<Op>);
      return true;
    }
  }
  return false;
}

template <typename... Ops>
bool enqueueFusibleOp(Operation *useOp, Operation *defOp,
    SmallVector<Operation *> &fusibleOps,
-    SmallVector<EmitScalarFunc> &fuseEmitFunctions);
+    SmallVector<EmitScalarFunc> &fuseEmitFunctions,
+    DimAnalysis *dimAnalysis);

template <typename Op, typename... Ops>
bool enqueueFusibleOp(Operation *useOp, Operation *defOp,
    SmallVector<Operation *> &fusibleOps,
-    SmallVector<EmitScalarFunc> &fuseEmitFunctions) {
-  if (enqueueFusibleOpImpl<Op>(useOp, defOp, fusibleOps, fuseEmitFunctions)) {
+    SmallVector<EmitScalarFunc> &fuseEmitFunctions,
+    DimAnalysis *dimAnalysis) {
+  if (enqueueFusibleOpImpl<Op>(
+          useOp, defOp, fusibleOps, fuseEmitFunctions, dimAnalysis))
    return true;
-  } else {
-    return enqueueFusibleOp<Ops...>(useOp, defOp, fusibleOps, fuseEmitFunctions);
-  }
+  return enqueueFusibleOp<Ops...>(
+      useOp, defOp, fusibleOps, fuseEmitFunctions, dimAnalysis);
}

template <>
bool enqueueFusibleOp(Operation *useOp, Operation *defOp,
    SmallVector<Operation *> &fusibleOps,
-    SmallVector<EmitScalarFunc> &fuseEmitFunctions) {
+    SmallVector<EmitScalarFunc> &fuseEmitFunctions,
+    DimAnalysis *dimAnalysis) {
  return false;
}

@@ -1689,7 +1700,7 @@ bool OpFusionHelper::checkFusibleOp(Operation *useOp, Operation *defOp,
      mlir::ONNXAddOp, mlir::ONNXAndOp, mlir::ONNXDivOp, mlir::ONNXMaxOp,
      mlir::ONNXMeanOp, mlir::ONNXMinOp, mlir::ONNXMulOp, mlir::ONNXOrOp,
      mlir::ONNXSubOp, mlir::ONNXSumOp, mlir::ONNXXorOp>(
-      useOp, defOp, fusibleOps, fuseEmitFunctions);
+      useOp, defOp, fusibleOps, fuseEmitFunctions, dimAnalysis);
 }
 
 // Only operations in the same block are allowed to fuse.
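The enqueueFusibleOp chain above walks a compile-time list of candidate op
classes, peeling one type off the template parameter pack per recursion step
until the empty pack returns false. A minimal self-contained sketch of the
same dispatch idiom, using a one-type base case instead of the patch's
empty-pack specialization (matchesAny, AddOp, SqrtOp, and opName are
hypothetical stand-ins, not onnx-mlir names):

    #include <iostream>
    #include <string>

    // One candidate left: test it directly.
    template <typename Op>
    bool matchesAny(const std::string &name) {
      return name == Op::opName();
    }

    // Two or more candidates: test the first, otherwise recurse on the rest.
    template <typename Op1, typename Op2, typename... Ops>
    bool matchesAny(const std::string &name) {
      return matchesAny<Op1>(name) || matchesAny<Op2, Ops...>(name);
    }

    struct AddOp {
      static const char *opName() { return "onnx.Add"; }
    };
    struct SqrtOp {
      static const char *opName() { return "onnx.Sqrt"; }
    };

    int main() {
      std::cout << matchesAny<AddOp, SqrtOp>("onnx.Sqrt"); // prints 1
      std::cout << matchesAny<AddOp, SqrtOp>("onnx.Conv"); // prints 0
    }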
@@ -1735,29 +1746,20 @@ bool OpFusionHelper::isControlFlowValidForFusion(
 // assumed the canonicalization has hoisted all constant to the beginning of the
 // function by fold function.
 bool OpFusionHelper::areInputsValidForFusion(
-    Operation *useOp, Operation *defOp) {
+    Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis) {
   // Elementwise unary operation is always fusible
   if (useOp->getOperands().size() == 1)
     return true;
 
-  // To fuse an Elementwise op with more than one operand with the producer,
-  // the output of the user Op has to have the same size as that of the
-  // producer Op. Here dimension expansion with size 1 is allowed. Refer to
-  // hasNoBroadcast() definition.
-  // ToFix: This PR simply checks static shape and does not use symbolic
-  // shape analysis and BroadcastShapeHelper.
-  // Some discussion can be found at
-  // https://github.com/onnx/onnx-mlir/issues/2199
-
-  if (!hasStaticShape(defOp->getResults()[0].getType()))
-    return false;
-
-  ArrayRef<int64_t> defShape = getShape(defOp->getResults()[0].getType());
-  ArrayRef<int64_t> useShape = getShape(useOp->getResults()[0].getType());
+  Type defOutputType = defOp->getResultTypes()[0];
+  Type useOutputType = useOp->getResultTypes()[0];
+  ArrayRef<int64_t> defShape = getShape(defOutputType);
+  ArrayRef<int64_t> useShape = getShape(useOutputType);
   if (defShape != useShape) {
     return false;
   }
 
+  // Check the inputs in the useOp
   for (size_t i = 0; i < useOp->getOperands().size(); i++) {
     // Only input from block argument and constant is allowed,
     // if the input does not come from the defining Op
@@ -1769,12 +1771,29 @@ bool OpFusionHelper::areInputsValidForFusion(
         return false;
       }
     }
+  }
+
+  // Check whether the output shape of the defOp is the same as the output
+  // shape of the useOp. If so, the iteration space of the defOp is sufficient
+  // for the element-wise operation of the useOp, even if MDBroadcast occurs
+  // in the useOp.
+  // Otherwise, the loop nest would have to be defined over the tensor with
+  // the larger iteration space.
 
+  // First check the rank
+  if (getRank(defOutputType) != getRank(useOutputType))
+    return false;
 
-    // ToFix: This restriction can be relaxed if ShapeHelper utility is used
-    // to generate load in future.
-    if (!hasStaticShape(useOp->getOperand(i).getType()))
+  if (dimAnalysis) {
+    if (!dimAnalysis->sameShape(defOp->getResult(0), useOp->getResult(0)))
       return false;
-    ArrayRef<int64_t> inputShape = getShape(useOp->getOperand(i).getType());
+  } else {
+    // If there is no dimAnalysis, check the simplest case:
+    // static and identical shapes.
+    if (!hasStaticShape(useOutputType))
+      return false;
+
+    ArrayRef<int64_t> inputShape = getShape(useOutputType);
     if (inputShape != defShape)
       return false;
   }
@@ -1819,7 +1838,8 @@ MemRefType OpFusionHelper::getOutputType(MemRefType outputType) {
 }
 
 // Emit fusion Ops
-Value OpFusionHelper::emitFuseOps(Value defOpResult, ValueRange loopInd) {
+Value OpFusionHelper::emitFuseOps(
+    Value defOpResult, Value alloc, ValueRange loopInd) {
   if (isFusibleListEmpty())
     return defOpResult;
 
@@ -1835,33 +1855,30 @@ Value OpFusionHelper::emitFuseOps(Value defOpResult, ValueRange loopInd) {
     MDBuilder create(rewriter, loc);
     Type currentElementType = getElementType(useOp->getResults()[0].getType());
 
-    // Prepare Values for EmitScalarOpFor
-    SmallVector<Value> inputValues;
-    // ToFix: expect to use new utility for this purpose.
-    // There is an issue to fix: cannot getRemappedValue for the Value that is
-    // currently being handled: the defOp.
-    // Otherwise, runtime error: "null operand found" caused by
-    // just calling the function without using the result!
-#if 0
+    // useOperands is used for the ShapeHelper and the load op.
+    // getRemappedValue is needed for the load op.
     SmallVector<Value> useOperands;
     for (auto oper : useOp->getOperands()) {
       if (oper.getDefiningOp() != defOp)
         useOperands.emplace_back(rewriter.getRemappedValue(oper));
+      else
+        // No load is needed because the scalar defOpResult is used directly;
+        // this value is only needed by the shape helper.
+        useOperands.emplace_back(alloc);
     }
-    LogicalResult res =
-        rewriter.getRemappedValues(useOp->getOperands(), useOperands);
-    assert(succeeded(res) && "Could not remap value for rewriter");
+    // Use shape helper to generate load index
     ONNXBroadcastOpShapeHelper shapeHelper(
         useOp, useOperands, &create.krnlIE, nullptr, false);
-#endif
+    shapeHelper.computeShapeAndAssertOnFailure();
+
+    // Prepare Values for EmitScalarOpFor
+    SmallVector<Value> inputValues;
     for (size_t i = 0; i < useOp->getOperands().size(); i++) {
       Value inputValue = useOp->getOperand(i);
       Operation *inputOp = inputValue.getDefiningOp();
       if (inputOp == defOp) {
         inputValues.emplace_back(defOpResult);
       } else {
-        // ToFix: expect to use new utility to handle any broadcast cases
-#if 0
         IndexExprScope innerScope(create.krnl, shapeHelper.getScope());
         SmallVector<IndexExpr> outputAccessExprs;
         getIndexExprList<DimIndexExpr>(loopInd, outputAccessExprs);
+        SmallVector<IndexExpr> loadAccessExprs;
         LogicalResult res = shapeHelper.getAccessExprs(
             inputValue, i, outputAccessExprs, loadAccessExprs, true);
         assert(succeeded(res) && "Could not compute access indices");
         Value load = create.krnl.loadIE(useOperands[i], loadAccessExprs);
-#endif
-        // The shape is guaranteed to be the same.
-        Value load =
-            create.krnl.load(rewriter.getRemappedValue(inputValue), loopInd);
         inputValues.emplace_back(load);
       }
     }
     defOpResult =
         emitScalar(rewriter, loc, useOp, currentElementType, inputValues);
     defOp = useOp;
+    alloc = defOp->getResult(0);
   }
   return defOpResult;
 }
@@ -1997,7 +2011,7 @@ struct ONNXElementwiseUnaryOpLowering
     LLVM_DEBUG(llvm::dbgs() << "  scalar execution\n");
 
     // Try to fuse the unary elementwise consumers
-    OpFusionHelper opFusionHelper(rewriter, op);
+    OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis);
     opFusionHelper.findFusibleOps();
     outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);
 
@@ -2034,7 +2048,7 @@ struct ONNXElementwiseUnaryOpLowering
           auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
               rewriter, loc, op, elementType, args);
           loweredOpResult =
-              opFusionHelper.emitFuseOps(loweredOpResult, loopInd);
+              opFusionHelper.emitFuseOps(loweredOpResult, alloc, loopInd);
           // Store result in the resulting array.
           createKrnl.store(loweredOpResult, alloc, loopInd);
         });
@@ -2055,7 +2069,7 @@ struct ONNXElementwiseUnaryOpLowering
       }
       auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
           rewriter, loc, op, elementType, args);
-      loweredOpResult = opFusionHelper.emitFuseOps(loweredOpResult);
+      loweredOpResult = opFusionHelper.emitFuseOps(loweredOpResult, alloc);
       // Store result in the resulting array.
create.krnl.store(loweredOpResult, alloc); } @@ -2165,7 +2179,7 @@ struct ONNXElementwiseBinaryOpLowering LLVM_DEBUG(llvm::dbgs() << " scalar execution\n"); // Try to fuse the unary elementwise consumers - OpFusionHelper opFusionHelper(rewriter, op); + OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis); opFusionHelper.findFusibleOps(); outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); @@ -2209,7 +2223,7 @@ struct ONNXElementwiseBinaryOpLowering Value result = emitScalarOpFor( rewriter, loc, op, outputElementType, {lhs, rhs}); - result = opFusionHelper.emitFuseOps(result, loopInd); + result = opFusionHelper.emitFuseOps(result, alloc, loopInd); // Store result in the resulting array. createKrnl.store(result, alloc, loopInd); }); @@ -2221,7 +2235,7 @@ struct ONNXElementwiseBinaryOpLowering Value result = emitScalarOpFor( rewriter, loc, op, outputElementType, {lhs, rhs}); - result = opFusionHelper.emitFuseOps(result); + result = opFusionHelper.emitFuseOps(result, alloc); // Store result in the resulting array. create.krnl.store(result, alloc); } @@ -2328,7 +2342,7 @@ struct ONNXElementwiseVariadicOpLowering LLVM_DEBUG(llvm::dbgs() << " scalar execution\n"); // Try to fuse the unary elementwise consumers - OpFusionHelper opFusionHelper(rewriter, op); + OpFusionHelper opFusionHelper(rewriter, op, dimAnalysis); opFusionHelper.findFusibleOps(); outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); @@ -2378,7 +2392,8 @@ struct ONNXElementwiseVariadicOpLowering Value finalResult = emitPostProcessingFor( rewriter, loc, op, outputElementType, accumulated); - finalResult = opFusionHelper.emitFuseOps(finalResult, loopInd); + finalResult = + opFusionHelper.emitFuseOps(finalResult, alloc, loopInd); // Store result in the resulting array. createKrnl.storeIE(finalResult, alloc, outputAccessExprs); }); @@ -2395,7 +2410,7 @@ struct ONNXElementwiseVariadicOpLowering } Value finalResult = emitPostProcessingFor( rewriter, loc, op, outputElementType, accumulated); - finalResult = opFusionHelper.emitFuseOps(finalResult); + finalResult = opFusionHelper.emitFuseOps(finalResult, alloc); // Store result in the resulting array. 
create.krnl.store(finalResult, alloc); } diff --git a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir index 4f46049226..65fccaac3d 100644 --- a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir +++ b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir @@ -178,3 +178,94 @@ func.func @fuse_element_14(%arg0: tensor<5xf32>) -> tensor<*xf32> { // CHECK: return [[RES_]] : memref<5xf32> // CHECK: } } + +// ----- + +func.func @fuse_element_15(%arg0: tensor<4x5xf32>, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.Sqrt"(%arg0) : (tensor<4x5xf32>) -> tensor<*xf32> + %1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %1 : tensor<*xf32> +} +// CHECK-LABEL: func.func @fuse_element_15 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4x5xf32>, [[PARAM_1_:%.+]]: memref) -> memref<4x5xf32> { +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4x5xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 4, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 5){ +// CHECK: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<4x5xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = math.sqrt [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]]#1] : memref +// CHECK: [[VAR_5_:%.+]] = arith.addf [[VAR_3_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<4x5xf32> +// CHECK: } +// CHECK: return [[RES_]] : memref<4x5xf32> +// CHECK: } + +// ----- + +func.func @fuse_element_16(%arg0: tensor<4x?xf32>, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.Sqrt"(%arg0) : (tensor<4x?xf32>) -> tensor<*xf32> + %1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %1 : tensor<*xf32> +} +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0, s1] -> (s1, s0)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-LABEL: func.func @fuse_element_16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4x?xf32>, [[PARAM_1_:%.+]]: memref) -> memref<4x?xf32> { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<4x?xf32> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref<4x?xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<4x?xf32> +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 4, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]]([[VAR_dim_0_]])){ +// CHECK: [[VAR_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1] : memref<4x?xf32> +// CHECK: [[VAR_5_:%.+]] = math.sqrt [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_5_]], [[RES_]]{{.}}[[VAR_3_]]#0, [[VAR_3_]]#1] : memref<4x?xf32> +// CHECK: } +// CHECK: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_1_]], [[CST_0_]] : memref +// CHECK: [[VAR_1_:%.+]] = affine.max 
[[MAP_1_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_1] +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc([[VAR_1_]]) {{.*}}: memref<4x?xf32> +// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to 4, [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_]], [[VAR_dim_]]_1, [[VAR_1_]])){ +// CHECK-DAG: [[VAR_3_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = arith.cmpi sgt, [[VAR_dim_]], [[CST_1_]] : index +// CHECK: [[VAR_5_1_:%.+]] = arith.select [[LOAD_PARAM_0_MEM_1_]], [[VAR_3_1_]]#1, [[CST_0_]] : index +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_3_1_]]#0, [[VAR_5_1_]]{{.}} : memref<4x?xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpi sgt, [[VAR_dim_1_]], [[CST_1_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_3_1_]]#1, [[CST_0_]] : index +// CHECK: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_8_]]{{.}} : memref +// CHECK: [[VAR_10_:%.+]] = arith.addf [[LOAD_RES_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_10_]], [[RES_1_]]{{.}}[[VAR_3_1_]]#0, [[VAR_3_1_]]#1] : memref<4x?xf32> +// CHECK: } +// CHECK: return [[RES_1_]] : memref<4x?xf32> +// CHECK: } + +// ----- + +func.func @fuse_element_17(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.Sqrt"(%arg0) : (tensor) -> tensor<*xf32> + %1 = "onnx.Add"(%0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %1 : tensor<*xf32> +} +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @fuse_element_17 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref) -> memref { +// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[MAP_0_]]([[VAR_dim_0_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 5){ +// CHECK: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref +// CHECK-DAG: [[VAR_3_:%.+]] = math.sqrt [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]]#1] : memref +// CHECK: [[VAR_5_:%.+]] = arith.addf [[VAR_3_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref +// CHECK: } +// CHECK: return [[RES_]] : memref +// CHECK: } + diff --git a/utils/pre-onnx-mlir.py b/utils/pre-onnx-mlir.py index 009e1e90e2..d344c2c741 100644 --- a/utils/pre-onnx-mlir.py +++ b/utils/pre-onnx-mlir.py @@ -15,6 +15,7 @@ import argparse from onnx import version_converter, helper +print("your onnx package version is " + str(onnx.__version__)) parser = argparse.ArgumentParser() parser.add_argument("model", help="onnx model") parser.add_argument("--save", help="save the converted model", action="store_true")
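
For reference, the round+saturate cast semantics that the new flag enables
(and that onnx_constprop_flags.mlir checks for the f16 -> i16 case) can be
reproduced with a short standalone C++ program. This is an illustrative
sketch of the semantics only; roundSaturateToI32 is a made-up helper, not
part of onnx-mlir, and it handles just TO = int32_t:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Round to nearest (ties to even) with saturation and NaN -> 0,
    // mirroring convertIntFromDouble in ElementsAttrBuilder.cpp above.
    static int32_t roundSaturateToI32(double from) {
      if (std::isnan(from))
        return 0;
      if (from < static_cast<double>(INT32_MIN))
        return INT32_MIN;
      if (from >= static_cast<double>(INT32_MAX))
        return INT32_MAX;
      // The default rounding mode rounds to nearest, ties to even.
      return static_cast<int32_t>(std::llrint(from));
    }

    int main() {
      const double inputs[] = {
          -1.5, -0.5, 0.4, 0.5, 1.5, 1.6, 1e10, -1e10, NAN};
      for (double d : inputs)
        std::printf("%g -> %d\n", d, roundSaturateToI32(d));
      // Prints -2, 0, 0, 0, 2, 2, 2147483647, -2147483648, 0 in turn,
      // matching the ROUND expectations in onnx_constprop_flags.mlir and
      // the saturation/NaN behavior of test_cast_f32_i32 above.
      return 0;
    }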