From a7bee2c84f42d9ee23f007b427b1b40501fe2ff6 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 29 May 2023 22:54:40 -0400 Subject: [PATCH] [StableHLO][CHLO] Port CHLO decomposition patterns These are ported from the mlir-hlo project. For more context, see the initial import: https://github.com/openxla/iree/pull/12957. The biggest difference it the removal of most FileCheck CHECK lines in tests. MHLO hardcoded thousands lines of exact decomposition sequences that fell apart after due to different canonicalizations and folds. Without a script to regenerate these CHECKs, these tests were not maintanable and I decided to drop them. Now we only check that the dialect conversion succeeded. Other notable differences to the MHLO implementation: - Ported some utility functions and tablegen defs. - New `chlo.tan` lowering, since StableHLO does not provide a tan op. Issue: https://github.com/openxla/iree/issues/13803 --- .../InputConversion/StableHLO/BUILD.bazel | 18 + .../StableHLO/CHLODecompositionPatterns.td | 371 ++++ .../InputConversion/StableHLO/CMakeLists.txt | 14 + .../StableHLO/LegalizeCHLO.cpp | 1823 ++++++++++++++++- .../StableHLO/test/BUILD.bazel | 1 + .../StableHLO/test/CMakeLists.txt | 1 + .../test/legalize_chlo_decomposition.mlir | 507 +++++ .../test/legalize_chlo_no_broadcast.mlir | 46 +- .../test/legalize_chlo_with_broadcast.mlir | 1 - 9 files changed, 2730 insertions(+), 52 deletions(-) create mode 100644 compiler/src/iree/compiler/InputConversion/StableHLO/CHLODecompositionPatterns.td create mode 100644 compiler/src/iree/compiler/InputConversion/StableHLO/test/legalize_chlo_decomposition.mlir diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel index 666d1f004990..a6b533afeb17 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel @@ -42,6 +42,23 @@ iree_compiler_cc_library( ], ) +iree_gentbl_cc_library( + name = "CHLODecompositionPatterns", + tbl_outs = [ + ( + ["--gen-rewriters"], + "CHLODecompositionPatterns.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "CHLODecompositionPatterns.td", + deps = [ + "@llvm-project//mlir:ShapeTdFiles", + "@mlir-hlo//stablehlo:chlo_ops_td_files", + "@mlir-hlo//stablehlo:stablehlo_ops_td_files", + ], +) + iree_compiler_cc_library( name = "StableHLOLegalization", srcs = [ @@ -66,6 +83,7 @@ iree_compiler_cc_library( "VerifyCompilerInputLegality.cpp", ], deps = [ + ":CHLODecompositionPatterns", ":PassHeaders", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CHLODecompositionPatterns.td b/compiler/src/iree/compiler/InputConversion/StableHLO/CHLODecompositionPatterns.td new file mode 100644 index 000000000000..fe852185b44d --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CHLODecompositionPatterns.td @@ -0,0 +1,371 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// This is the legalization pattern definition file for CHLO to MHLO. +// These are included in the PopulateDecomposeChloPatterns factory +// and should only include canonical expansions which are not actually +// ambiguous/different for various backends. Avoid patterns that are actually +// lowering to non-canonical forms. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "stablehlo/dialect/ChloOps.td" +include "stablehlo/dialect/StablehloOps.td" + +class StableHLO_ComparisonDirectionValue : + ConstantAttr; + +class ConstantLike : NativeCodeCall< + "::mlir::iree_compiler::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + +def ComplexElementType : Type< + CPred<"isa(cast($_self).getElementType())">, + "Complex element type">; + +def NonComplexElementType : Type< + CPred<"!isa(cast($_self).getElementType())">, + "Non-complex element type">; + +def ConstantLikeMaxFiniteValue : NativeCodeCall< + "::mlir::iree_compiler::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + +def ConstantLikePosInfValue : NativeCodeCall< + "::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; + +def ConstantLikeNegInfValue : NativeCodeCall< + "::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">; + +//===----------------------------------------------------------------------===// +// Unary op patterns. +//===----------------------------------------------------------------------===// + +// Expand acos for non-complex arguments to MHLO dialect as follows: +// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 +// = pi if x == -1 +// +// TODO(b/237376133): Support operands with complex element types separately +// using the following formula. +// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) +def : Pat<(CHLO_AcosOp NonComplexElementType:$input), + (StableHLO_SelectOp + (StableHLO_CompareOp + $input, + (ConstantLike<"-1"> $input), + StableHLO_ComparisonDirectionValue<"NE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + ), + (StableHLO_MulOp + (ConstantLike<"2"> $input), + (StableHLO_Atan2Op + (StableHLO_SqrtOp + (StableHLO_SubtractOp + (ConstantLike<"1"> $input), + (StableHLO_MulOp $input, $input) + ) + ), + (StableHLO_AddOp + (ConstantLike<"1"> $input), + $input + ) + ) + ), + (ConstantLike<"M_PI"> $input) + )>; + +// Expand acosh to MHLO dialect as follows: +// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 +// = log(x + sqrt((x+1)*(x-1))) +// acosh(x) = nan if x < -1 +// +// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as +// log(2*x) = log(2) + log(x). (Note this works because negative x never +// overflows; x < -1 simply yields nan. +def : Pat<(CHLO_AcoshOp NonComplexElementType:$input), + (StableHLO_SelectOp + (StableHLO_CompareOp + $input, + (ConstantLike<"-1"> $input), + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + ), + (ConstantLike<"NAN"> $input), + (StableHLO_SelectOp + (StableHLO_CompareOp + $input, + (StableHLO_SqrtOp + (ConstantLikeMaxFiniteValue $input) + ), + StableHLO_ComparisonDirectionValue<"GE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + ), + (StableHLO_AddOp + (StableHLO_LogOp $input), + (StableHLO_LogOp + (ConstantLike<"2"> $input) + ) + ), + (StableHLO_LogOp + (StableHLO_AddOp + $input, + (StableHLO_SqrtOp + (StableHLO_MulOp + (StableHLO_AddOp + (ConstantLike<"1"> $input), + $input + ), + (StableHLO_AddOp + (ConstantLike<"-1"> $input), + $input + ) + ) + ) + ) + ) + ) + )>; + +// Expand acosh for complex arguments to MHLO dialect as +// acosh(x) = log(x + sqrt((x+1)*(x-1))) +// +// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: +// "For now, we ignore the question of overflow if x is a +// complex type, because we don't yet have exhaustive tests for complex trig +// functions". +def : Pat<(CHLO_AcoshOp ComplexElementType:$input), + (StableHLO_LogOp + (StableHLO_AddOp + $input, + (StableHLO_SqrtOp + (StableHLO_MulOp + (StableHLO_AddOp + $input, + (ConstantLike<"1"> $input) + ), + (StableHLO_SubtractOp + $input, + (ConstantLike<"1"> $input) + ) + ) + ) + ) + )>; + + +// Expand asin to MHLO dialect as follows: +// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +def : Pat<(CHLO_AsinOp $input), + (StableHLO_MulOp + (ConstantLike<"2"> $input), + (StableHLO_Atan2Op + $input, + (StableHLO_AddOp + (ConstantLike<"1"> $input), + (StableHLO_SqrtOp + (StableHLO_SubtractOp + (ConstantLike<"1"> $input), + (StableHLO_MulOp $input, $input) + ) + ) + ) + ) + )>; + +// Expand asinh for non-complex arguments to MHLO dialect as +// asinh(x) = log(x + sqrt(x^2 + 1)) +// +// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) +// as 2*x and return log(2) + log(x). +// +// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point +// arithmetic. However, we would like to retain the low order term of this, +// which is around 0.5 * x^2 using a binomial expansion. +// Let z = sqrt(a^2 + 1) +// The following rewrite retains the lower order term. +// log(a + sqrt(a^2 + 1)) +// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) +// = log((a + a^2 + 1 + a * z + z) / (1 + z)) +// = log(1 + a + a^2 / (1 + z)) +// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) +// +// If x is negative, the above would give us some trouble; we can't approximate +// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = +// -asinh(x). +def : Pat<(CHLO_AsinhOp NonComplexElementType:$input), + (StableHLO_MulOp + (StableHLO_SignOp $input), + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_AbsOp $input), + (StableHLO_SqrtOp + (ConstantLikeMaxFiniteValue $input) + ), + StableHLO_ComparisonDirectionValue<"GE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + ), + (StableHLO_AddOp + (StableHLO_LogOp + (StableHLO_AbsOp $input) + ), + (StableHLO_LogOp + (ConstantLike<"2"> $input) + ) + ), + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_AbsOp $input), + (ConstantLike<"1"> $input), + StableHLO_ComparisonDirectionValue<"LE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + ), + (StableHLO_Log1pOp + (StableHLO_AddOp + (StableHLO_AbsOp $input), + (StableHLO_MulOp + (StableHLO_AbsOp $input), + (StableHLO_DivOp + (StableHLO_AbsOp $input), + (StableHLO_AddOp + (ConstantLike<"1"> $input), + (StableHLO_SqrtOp + (StableHLO_AddOp + (StableHLO_MulOp + (StableHLO_AbsOp $input), + (StableHLO_AbsOp $input) + ), + (ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + ), + (StableHLO_LogOp + (StableHLO_AddOp + (StableHLO_AbsOp $input), + (StableHLO_SqrtOp + (StableHLO_AddOp + (StableHLO_MulOp + (StableHLO_AbsOp $input), + (StableHLO_AbsOp $input) + ), + (ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + )>; + +// Expand asinh for complex arguments to MHLO dialect as +// asinh(x) = log(x + sqrt(x^2 + 1)) +// +// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: +// "For now, we ignore the question of overflow if x is a +// complex type, because we don't yet have exhaustive tests for complex trig +// functions". +def : Pat<(CHLO_AsinhOp ComplexElementType:$input), + (StableHLO_LogOp + (StableHLO_AddOp + $input, + (StableHLO_SqrtOp + (StableHLO_AddOp + (StableHLO_MulOp $input, $input), + (ConstantLike<"1"> $input) + ) + ) + ) + )>; + +// Express `atan` as +// atan(x) = atan2(x, 1) +def : Pat<(CHLO_AtanOp $input), + (StableHLO_Atan2Op + $input, + (ConstantLike<"1"> $input) + )>; + +// Express `atanh` for non-complex arguments as follows: +// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 +// atanh(x) = nan otherwise +def : Pat<(CHLO_AtanhOp NonComplexElementType:$input), + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_AbsOp $input), + (ConstantLike<"1"> $input), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + ), + (ConstantLike<"NAN"> $input), + (StableHLO_MulOp + (StableHLO_SubtractOp + (StableHLO_Log1pOp $input), + (StableHLO_Log1pOp + (StableHLO_NegOp $input) + ) + ), + (ConstantLike<"0.5"> $input) + ) + )>; + +// Express `atanh` for complex arguments as follows: +// atanh(x) = (log(1 + x) - log(1 + (-x))) * 0.5 +// +// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: +// "For now, we ignore the nan edge case for complex inputs, +// because we don't yet have exhaustive tests for complex trig functions". +def : Pat<(CHLO_AtanhOp ComplexElementType:$input), + (StableHLO_MulOp + (StableHLO_SubtractOp + (StableHLO_Log1pOp $input), + (StableHLO_Log1pOp + (StableHLO_NegOp $input) + ) + ), + (ConstantLike<"0.5"> $input) + )>; + +// Express `conj` as +// conj(x) = (re(x), -im(x)). +def : Pat<(CHLO_ConjOp $v), + (StableHLO_ComplexOp (StableHLO_RealOp $v), (StableHLO_NegOp (StableHLO_ImagOp $v)))>; + +// Express `is_inf` as +// is_inf(x) = is_pos_inf(|x|) +def : Pat<(CHLO_IsInfOp NonComplexElementType:$input), + (CHLO_IsPosInfOp + (StableHLO_AbsOp $input) + )>; + +// Express `is_pos_inf` as +// is_pos_inf(x) = (x == +inf) +def : Pat<(CHLO_IsPosInfOp NonComplexElementType:$input), + (StableHLO_CompareOp + $input, + (ConstantLikePosInfValue $input), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + )>; + +// Express `is_neg_inf` as +// is_neg_inf(x) = (x == -inf) +def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input), + (StableHLO_CompareOp + $input, + (ConstantLikeNegInfValue $input), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + )>; + +// Express `tan` as +// sine(x) / cosine(x) +def : Pat<(CHLO_TanOp NonComplexElementType:$input), + (StableHLO_DivOp + (StableHLO_SineOp $input), + (StableHLO_CosineOp $input) + )>; diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt index a4b3bc672230..26c54f4c25c4 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt @@ -1,3 +1,7 @@ +# Add this tablegen include to support CHLO rewrites with DRR. +list(APPEND IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${IREE_SOURCE_DIR}/third_party/mlir-hlo/stablehlo") + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_ABOVE_THIS_LINE ### ################################################################################ # Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # # compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel # @@ -34,6 +38,15 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + CHLODecompositionPatterns + TD_FILE + "CHLODecompositionPatterns.td" + OUTS + --gen-rewriters CHLODecompositionPatterns.h.inc +) + iree_cc_library( NAME StableHLOLegalization @@ -58,6 +71,7 @@ iree_cc_library( "TypeConversion.h" "VerifyCompilerInputLegality.cpp" DEPS + ::CHLODecompositionPatterns ::PassHeaders ChloOps IREELinalgExtDialect diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp index f335483545c5..81941b286e16 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp @@ -4,7 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops. +// Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops, +// taking care of CHLO's broadcasting semantics #include "iree/compiler/InputConversion/StableHLO/Passes.h" #include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h" @@ -16,6 +17,7 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -37,7 +39,7 @@ namespace { template struct HloNaryElementwiseAdaptor { static ToOpTy createOp(FromOpTy fromOp, Type resultType, - ValueRange broadcastedOperands, OpBuilder& builder) { + ValueRange broadcastedOperands, OpBuilder &builder) { return builder.create(fromOp.getLoc(), resultType, broadcastedOperands); } @@ -82,21 +84,21 @@ static std::optional toStableHloComparisonType( struct HloCompareAdaptor { static mlir::stablehlo::CompareOp createOp( mlir::chlo::BroadcastCompareOp fromOp, Type resultType, - ValueRange broadcastedOperands, OpBuilder& builder) { + ValueRange broadcastedOperands, OpBuilder &builder) { auto chloDirection = fromOp.getComparisonDirection(); - auto mhloDirection = toStableHloComparisonDirection(chloDirection); - if (!mhloDirection) return nullptr; + auto hloDirection = toStableHloComparisonDirection(chloDirection); + if (!hloDirection) return nullptr; auto chloType = fromOp.getCompareType().value_or(mlir::chlo::ComparisonType::NOTYPE); - auto mhloType = toStableHloComparisonType(chloType); - if (!mhloType) return nullptr; - auto mhloTypeAttr = fromOp.getCompareType() - ? mlir::stablehlo::ComparisonTypeAttr::get( - builder.getContext(), *mhloType) - : nullptr; + auto hloType = toStableHloComparisonType(chloType); + if (!hloType) return nullptr; + auto hloTypeAttr = fromOp.getCompareType() + ? mlir::stablehlo::ComparisonTypeAttr::get( + builder.getContext(), *hloType) + : nullptr; return builder.create( fromOp.getLoc(), resultType, broadcastedOperands[0], - broadcastedOperands[1], *mhloDirection, mhloTypeAttr); + broadcastedOperands[1], *hloDirection, hloTypeAttr); } }; @@ -104,9 +106,9 @@ struct HloCompareAdaptor { // to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values. template