diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp index 2e2f4ac4222f..f335483545c5 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeCHLO.cpp @@ -279,6 +279,17 @@ struct ConvertRankedDynamicBroadcastBinaryOp final } }; +struct ConvertConstantOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::chlo::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + struct ConvertConstantLikeOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -483,8 +494,7 @@ void populateLegalizeChloPatterns(MLIRContext* context, context, patterns, 10); populateForBroadcastingBinaryOp( context, patterns, 5); - patterns - ->add( - context); + patterns->add(context); } } // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp index f44d4826a200..278287e9f2a5 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp @@ -110,6 +110,7 @@ void buildStableHLOInputConversionPassPipelineImpl(OpPassManager &passManager, stablehlo::createLegalizeShapeComputations()); passManager.addNestedPass( stablehlo::createConvertStableHloToLinalgExt()); + passManager.addNestedPass(stablehlo::createLegalizeChlo()); passManager.addPass(createConvertStableHloToIreeInputDialects()); // Ensure conversion completed. passManager.addPass(createReconcileUnrealizedCastsPass()); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp index 5179ce6e3d49..f4bcdf462b9a 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp @@ -432,8 +432,6 @@ struct ConvertStableHloToIreeInputDialects final // expensive expansions. populateCanonicalizationPatterns(context, &patterns, /*benefit=*/1024); - // TODO(#12678): Handle chlo lowering. - populateStableHloToLinalgOnTensorsConversionPatterns( context, *typeConverter, &patterns); populateStableHloCollectivesConversionPatterns(context, *typeConverter, diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/legalize_chlo_no_broadcast.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/legalize_chlo_no_broadcast.mlir index 4bc4c3b90d12..b12486376071 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/legalize_chlo_no_broadcast.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/legalize_chlo_no_broadcast.mlir @@ -4,6 +4,18 @@ // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics. +// CHECK-LABEL: @constants +func.func @constants() -> (tensor<4xi32>, tensor<2x2xf32>) { + %0 = chlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %1 = chlo.constant dense<0.0> : tensor<2x2xf32> + + // CHECK-DAG: stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK-DAG: stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32> + func.return %0, %1 : tensor<4xi32>, tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: @addWithoutBroadcast func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.add %arg0, %arg1