Skip to content

Commit

Permalink
Integrate StableHLO at openxla/stablehlo@20255865
Browse files Browse the repository at this point in the history
Other than the usual integration, the CL does two things:

The upstream change openxla/stablehlo#1869 in StableHLO updates various API related to shape inference. MHLO shape inference functions uses those APIs. The CL fixes the invocation of those APIs in MHLO codebase so as to sync the semantics of StableHLO reduction operation with MHLO.

There exists canonicalization passes like group-reduction-dimensions and hlo-canonicalize-reduction which create reduce operation using builder methods that calls type inference of reduce op with empty reduction region example. This is problematic as, with the change, the type inference of reduce op is now dependent on the reduction body. The CL updates all the calls sites of the problematic builder (the one which calls type inference with empty reduction block) with the invocation of a new custom builder method introduced for mhlo::Reduce operation.

Note that at the moment we do not need similar custom builder for other reduction based operations (like scatter, reduce_scatter, all_reduce, select_and_scatter, reduce_window) as they are presently created using a builder version take result type as an input and hence does not call inference from within.

Also, the CL adds verification tests for the operations with promotable semantics.

PiperOrigin-RevId: 599930190
  • Loading branch information
sdasgup3 authored and TensorFlow MLIR Team committed Jan 19, 2024
1 parent 48f90ae commit 98f4618
Show file tree
Hide file tree
Showing 33 changed files with 487 additions and 470 deletions.
351 changes: 70 additions & 281 deletions mhlo/IR/hlo_ops.cc

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1490,7 +1490,6 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]>
}

def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [
SameOperandsAndResultElementType,
SingleBlockImplicitTerminator<"ReturnOp">,
InferTensorType
]> {
Expand Down Expand Up @@ -1544,8 +1543,7 @@ def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [
let hasCustomHLOConverter = 1;
}

def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter",
[SameOperandsAndResultElementType]> {
def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", []> {
let summary = "ReduceScatter operation";
let description = [{
Within each process group in the process grid, performs reduction, using
Expand Down Expand Up @@ -1691,6 +1689,12 @@ def MHLO_ReduceOp: MHLO_ShapedInterfaceOp<"reduce", [
// compatible with reduce op's operands.
let regions = (region SizedRegion<1>:$body);

// Builder
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values,
"DenseIntElementsAttr":$dimensions, "TypeRange":$element_types)>,
];

// TODO(b/129422361): ReduceOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -224,8 +225,11 @@ LogicalResult tryLowerTo1DOr2DReduction(
int64_t reductionDim = leadingReduction ? 0 : 1;
auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim});
Value initVal = op.getInitValues().front();
auto reductionOp =
rewriter.create<ReduceOp>(loc, intermResult, initVal, reductionDimAttr);
SmallVector<Type> elementTypes{llvm::map_range(
op.getBody().front().getTerminator()->getOperands(),
[](Value v) { return v.getType().cast<ShapedType>().getElementType(); })};
auto reductionOp = rewriter.create<ReduceOp>(loc, intermResult, initVal,
reductionDimAttr, elementTypes);
rewriter.inlineRegionBefore(op.getBody(), reductionOp.getBody(),
reductionOp.getBody().begin());
intermResult = reductionOp->getResults().front();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -219,8 +220,12 @@ struct HloCanonicalizeReductionPass
elemTy),
operand, newOperandShape));
}
auto newOp =
b.create<ReduceOp>(loc, newOperands, op.getInitValues(), attr);
SmallVector<Type> elementTypes{llvm::map_range(
op.getBody().front().getTerminator()->getOperands(), [](Value v) {
return v.getType().cast<ShapedType>().getElementType();
})};
auto newOp = b.create<ReduceOp>(loc, newOperands, op.getInitValues(),
attr, elementTypes);
newOp.getBody().takeBody(op.getBody());

SmallVector<Value, 4> newResults;
Expand Down
1 change: 0 additions & 1 deletion stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ cc_library(
hdrs = [
"stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.h",
"stablehlo/conversions/linalg/transforms/MapStablehloToScalarOp.h",
"stablehlo/conversions/linalg/transforms/PassDetail.h",
"stablehlo/conversions/linalg/transforms/Passes.h",
"stablehlo/conversions/linalg/transforms/Rewriters.h",
"stablehlo/conversions/linalg/transforms/TypeConversion.h",
Expand Down
1 change: 0 additions & 1 deletion stablehlo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ cc_library(
hdrs = [
"stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.h",
"stablehlo/conversions/linalg/transforms/MapStablehloToScalarOp.h",
"stablehlo/conversions/linalg/transforms/PassDetail.h",
"stablehlo/conversions/linalg/transforms/Passes.h",
"stablehlo/conversions/linalg/transforms/Rewriters.h",
"stablehlo/conversions/linalg/transforms/TypeConversion.h",
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ http_archive(
],
)

LLVM_COMMIT = "baba0a4cb43181a78881fce683e3a5016daa8ce6"
LLVM_COMMIT = "3a82a1c3f6bdc9259cc4641f66fc76d1e171e382"

LLVM_SHA256 = "a81c8c08b7fc11a9668b2ed3e37a3e98ad8f9e4e4f6ba2c8b0b36e105a775d4e"
LLVM_SHA256 = "c525cdb14bb239695852d696bcd13a6d47e579be18386ba2048515fe7f059153"

http_archive(
name = "llvm-raw",
Expand Down
13 changes: 9 additions & 4 deletions stablehlo/build_tools/github_actions/ci_build_cmake.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,25 @@ fi

LLVM_BUILD_DIR="$1"
STABLEHLO_BUILD_DIR="$2"
CMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE:-RelWithDebInfo}"

# Turn on building Python bindings
STABLEHLO_ENABLE_BINDINGS_PYTHON="${STABLEHLO_ENABLE_BINDINGS_PYTHON:-OFF}"

# Configure StableHLO
cmake -GNinja \
-B"$STABLEHLO_BUILD_DIR" \
-DLLVM_ENABLE_LLD=ON \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=On \
-DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR="$LLVM_BUILD_DIR/lib/cmake/mlir" \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DSTABLEHLO_ENABLE_STRICT_BUILD=On
-DSTABLEHLO_ENABLE_STRICT_BUILD=ON \
-DSTABLEHLO_ENABLE_BINDINGS_PYTHON="${STABLEHLO_ENABLE_BINDINGS_PYTHON}"

# Build and Test StableHLO
cd "$STABLEHLO_BUILD_DIR"
cd "$STABLEHLO_BUILD_DIR" || exit
ninja check-stablehlo-ci
7 changes: 6 additions & 1 deletion stablehlo/build_tools/github_actions/ci_build_cmake_llvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ fi
LLVM_SRC_DIR="$1"
LLVM_BUILD_DIR="$2"

CMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE:-RelWithDebInfo}"
# Turn on building Python bindings
MLIR_ENABLE_BINDINGS_PYTHON="${MLIR_ENABLE_BINDINGS_PYTHON:-OFF}"

# Configure LLVM
cmake -GNinja \
"-H$LLVM_SRC_DIR/llvm" \
Expand All @@ -34,10 +38,11 @@ cmake -GNinja \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_TARGETS_TO_BUILD=host \
-DLLVM_INCLUDE_TOOLS=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON="${MLIR_ENABLE_BINDINGS_PYTHON}" \
-DLLVM_ENABLE_BINDINGS=OFF \
-DLLVM_BUILD_TOOLS=OFF \
-DLLVM_INCLUDE_TESTS=OFF \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
-DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" \
-DLLVM_ENABLE_ASSERTIONS=On \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
Expand Down
44 changes: 0 additions & 44 deletions stablehlo/build_tools/github_actions/ci_build_cmake_python_api.sh

This file was deleted.

2 changes: 1 addition & 1 deletion stablehlo/docs/_toc.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The StableHLO Authors.
# Copyright 2024 The StableHLO Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
13 changes: 13 additions & 0 deletions stablehlo/stablehlo/conversions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
# Copyright 2024 The StableHLO Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

add_subdirectory(linalg)
add_subdirectory(tosa)
13 changes: 13 additions & 0 deletions stablehlo/stablehlo/conversions/linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
# Copyright 2024 The StableHLO Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

add_subdirectory(tests)
add_subdirectory(transforms)
13 changes: 13 additions & 0 deletions stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright 2024 The StableHLO Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name StablehloLinalgTransforms)
add_public_tablegen_target(StablehloLinalgTransformsPassIncGen)
Expand Down
33 changes: 0 additions & 33 deletions stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h

This file was deleted.

6 changes: 2 additions & 4 deletions stablehlo/stablehlo/conversions/linalg/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ class ModuleOp;

namespace stablehlo {

#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "stablehlo/conversions/linalg/transforms/Passes.h.inc"

std::unique_ptr<OperationPass<ModuleOp>> createStablehloLegalizeToLinalgPass();

void registerStablehloLegalizeToLinalgPass();

} // namespace stablehlo
} // namespace mlir

Expand Down
12 changes: 10 additions & 2 deletions stablehlo/stablehlo/conversions/linalg/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@ limitations under the License.

include "mlir/Pass/PassBase.td"

def StablehloLegalizeToLinalg
def StablehloLegalizeToLinalgPass
: Pass<"stablehlo-legalize-to-linalg", "mlir::ModuleOp"> {
let summary = "Legalize StableHLO to LinAlg";
let constructor = "mlir::stablehlo::createStablehloLegalizeToLinalgPass()";
let dependentDialects = [
"mlir::bufferization::BufferizationDialect",
"mlir::complex::ComplexDialect",
"mlir::linalg::LinalgDialect",
"mlir::math::MathDialect",
"mlir::memref::MemRefDialect",
"mlir::scf::SCFDialect",
"mlir::shape::ShapeDialect",
];
let options = [Option<"enablePrimitiveOps", "enable-primitive-ops", "bool",
/*default=*/"false",
"Lower to primitive Linalg ops (map, reduce and "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
Expand All @@ -39,6 +38,9 @@ limitations under the License.

namespace mlir::stablehlo {

#define GEN_PASS_DEF_STABLEHLOLEGALIZETOLINALGPASS
#include "stablehlo/conversions/linalg/transforms/Passes.h.inc"

namespace {

Value getResultValue(Operation *op) { return op->getResult(0); }
Expand Down Expand Up @@ -2570,14 +2572,9 @@ static void populateConversionPatterns(MLIRContext *context,
linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
}

class StablehloLegalizeToLinalgPass
: public StablehloLegalizeToLinalgBase<StablehloLegalizeToLinalgPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
scf::SCFDialect, complex::ComplexDialect, math::MathDialect,
memref::MemRefDialect, shape::ShapeDialect>();
}
struct StablehloLegalizeToLinalgPass
: impl::StablehloLegalizeToLinalgPassBase<StablehloLegalizeToLinalgPass> {
using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase;

void runOnOperation() override {
auto *context = &getContext();
Expand All @@ -2603,14 +2600,3 @@ class StablehloLegalizeToLinalgPass
};
} // namespace
} // namespace mlir::stablehlo

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::stablehlo::createStablehloLegalizeToLinalgPass() {
return std::make_unique<StablehloLegalizeToLinalgPass>();
}

void mlir::stablehlo::registerStablehloLegalizeToLinalgPass() {
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::stablehlo::createStablehloLegalizeToLinalgPass();
});
}
3 changes: 1 addition & 2 deletions stablehlo/stablehlo/conversions/tosa/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ limitations under the License.
namespace mlir {
namespace tosa {

#define GEN_PASS_DECL_STABLEHLOLEGALIZETOTOSAPASS
#define GEN_PASS_DECL_STABLEHLOPREPAREFORTOSAPASS
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "stablehlo/conversions/tosa/transforms/Passes.h.inc"

Expand Down
Loading

0 comments on commit 98f4618

Please sign in to comment.