Skip to content

Commit

Permalink
[StableHLO][CHLO] Add CHLO to stablehlo pipeline (iree-org#13835)
Browse files Browse the repository at this point in the history
Unlike in the MHLO pipeline, run CHLO legalization as a separate pass to
better controll required canonicalization patterns. This also decouples
the CHLO conversion from the custom type converter used during lowering
to linalg and IREE dialects.

Also add a pattern to handle non-broadcasting constants.

Issue: iree-org#13803
  • Loading branch information
kuhar authored May 29, 2023
1 parent 4c2c575 commit b218383
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ struct ConvertRankedDynamicBroadcastBinaryOp final
}
};

struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::chlo::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, op.getValue());
return success();
}
};

struct ConvertConstantLikeOp final
: OpConversionPattern<mlir::chlo::ConstantLikeOp> {
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -483,8 +494,7 @@ void populateLegalizeChloPatterns(MLIRContext* context,
context, patterns, 10);
populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
patterns
->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
context);
patterns->add<ConvertConstantOp, ConvertConstantLikeOp,
ConvertDynamicReshapeOp, ConvertSelectOp>(context);
}
} // namespace mlir::iree_compiler::stablehlo
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
stablehlo::createLegalizeShapeComputations());
passManager.addNestedPass<func::FuncOp>(
stablehlo::createConvertStableHloToLinalgExt());
passManager.addNestedPass<func::FuncOp>(stablehlo::createLegalizeChlo());
passManager.addPass(createConvertStableHloToIreeInputDialects());
// Ensure conversion completed.
passManager.addPass(createReconcileUnrealizedCastsPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,6 @@ struct ConvertStableHloToIreeInputDialects final
// expensive expansions.
populateCanonicalizationPatterns(context, &patterns, /*benefit=*/1024);

// TODO(#12678): Handle chlo lowering.

populateStableHloToLinalgOnTensorsConversionPatterns(
context, *typeConverter, &patterns);
populateStableHloCollectivesConversionPatterns(context, *typeConverter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@
// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.

// CHECK-LABEL: @constants
func.func @constants() -> (tensor<4xi32>, tensor<2x2xf32>) {
%0 = chlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%1 = chlo.constant dense<0.0> : tensor<2x2xf32>

// CHECK-DAG: stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32>
func.return %0, %1 : tensor<4xi32>, tensor<2x2xf32>
}

// -----

// CHECK-LABEL: @addWithoutBroadcast
func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: stablehlo.add %arg0, %arg1
Expand Down

0 comments on commit b218383

Please sign in to comment.