diff --git a/src/Conversion/ONNXToStableHlo/ConvertONNXToStableHlo.cpp b/src/Conversion/ONNXToStableHlo/ConvertONNXToStableHlo.cpp
index 082a53c17b..4f5d15b09c 100644
--- a/src/Conversion/ONNXToStableHlo/ConvertONNXToStableHlo.cpp
+++ b/src/Conversion/ONNXToStableHlo/ConvertONNXToStableHlo.cpp
@@ -42,7 +42,9 @@ void populateONNXToStableHloConversionPattern(
   populateLoweringONNXExpandOpToStableHloPattern(patterns, ctx);
   populateLoweringONNXFlattenOpToStableHloPattern(patterns, ctx);
   populateLoweringONNXGatherOpToStableHloPattern(patterns, ctx);
+  populateLoweringONNXGatherElementsOpToStableHloPattern(patterns, ctx);
   populateLoweringONNXIdentityOpToStableHloPattern(patterns, ctx);
+  populateLoweringONNXPadOpToStableHloPattern(patterns, ctx);
   populateLoweringONNXReshapeOpToStableHloPattern(patterns, ctx);
   populateLoweringONNXShapeOpToStableHloPattern(patterns, ctx);
   populateLoweringONNXSliceOpToStableHloPattern(patterns, ctx);
@@ -92,7 +94,8 @@ void FrontendToStableHloLoweringPass::runOnOperation() {
   // Added affine as some affine maps are generated by IndexExpression. It could
   // be disabled and/or replaced by shape max/min.
   target.addLegalDialect<stablehlo::StablehloDialect, func::FuncDialect,
-      arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect>();
+      arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect,
+      tensor::TensorDialect>();
   // Needed to support unsigned int computations. To be removed if we use a
   // scheme that does not rely on the UnrealizedConversionCastOp.
   target.addLegalOp<::mlir::UnrealizedConversionCastOp>();
diff --git a/src/Conversion/ONNXToStableHlo/Tensor/GatherElements.cpp b/src/Conversion/ONNXToStableHlo/Tensor/GatherElements.cpp
index 9100eb93b2..1b320421bd 100644
--- a/src/Conversion/ONNXToStableHlo/Tensor/GatherElements.cpp
+++ b/src/Conversion/ONNXToStableHlo/Tensor/GatherElements.cpp
@@ -57,17 +57,29 @@ struct ONNXGatherElementsOpLoweringToStableHlo : public ConversionPattern {
     Value zero = getShapedZero(loc, rewriter, indices);
     Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, data);
     Value indicesShape = rewriter.create<shape::ShapeOfOp>(loc, indices);
-    Value axisDimSize =
-        rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
-    axisDimSize =
-        rewriter.create<arith::IndexCastOp>(loc, indexElemType, axisDimSize);
-    axisDimSize = rewriter.create<tensor::FromElementsOp>(loc, axisDimSize);
-    axisDimSize = rewriter.create<stablehlo::ReshapeOp>(loc,
-        RankedTensorType::get(SmallVector<int64_t>{}, indexElemType),
-        axisDimSize);
-    Value broadcastedAxisDimSize =
-        rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, indicesType,
-            axisDimSize, indicesShape, rewriter.getI64TensorAttr({}));
+    Value broadcastedAxisDimSize, axisDimSize;
+    if (inputType.hasStaticShape()) {
+      axisDimSize = rewriter.create<stablehlo::ConstantOp>(
+          loc, rewriter.getIntegerAttr(
+                   indexElemType, inputType.getDimSize(axisLit)));
+    } else {
+      axisDimSize =
+          rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
+      axisDimSize =
+          rewriter.create<arith::IndexCastOp>(loc, indexElemType, axisDimSize);
+      axisDimSize = rewriter.create<tensor::FromElementsOp>(loc, axisDimSize);
+      axisDimSize = rewriter.create<stablehlo::ReshapeOp>(loc,
+          RankedTensorType::get(SmallVector<int64_t>{}, indexElemType),
+          axisDimSize);
+    }
+    if (indicesType.hasStaticShape()) {
+      broadcastedAxisDimSize = rewriter.create<stablehlo::BroadcastInDimOp>(
+          loc, indicesType, axisDimSize, rewriter.getI64TensorAttr({}));
+    } else {
+      broadcastedAxisDimSize =
+          rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, indicesType,
+              axisDimSize, indicesShape, rewriter.getI64TensorAttr({}));
+    }
     Value isNegative = rewriter.create<stablehlo::CompareOp>(
         loc, indices, zero, stablehlo::ComparisonDirection::LT);
     Value positiveIndices = rewriter.create<stablehlo::AddOp>(
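
A note on the GatherElements hunk above: when the data shape is static, the axis size now folds to a stablehlo.constant, and when the indices shape is static the plain stablehlo.broadcast_in_dim is used, so fully static graphs no longer need shape-dialect ops. For the dynamic branch, a minimal hand-written sketch of the IR it is expected to emit is below, assuming a tensor<?x2xf32> input and axis = 0; the function and value names are illustrative, not from the test suite:

func.func @axis_size_dynamic(%data: tensor<?x2xf32>, %indices: tensor<?x2xi64>) -> tensor<?x2xi64> {
  %c0 = arith.constant 0 : index
  // shape.shape_of / shape.get_extent fetch the runtime size of the gather axis
  %data_shape = shape.shape_of %data : tensor<?x2xf32> -> tensor<2xindex>
  %indices_shape = shape.shape_of %indices : tensor<?x2xi64> -> tensor<2xindex>
  %extent = shape.get_extent %data_shape, %c0 : tensor<2xindex>, index -> index
  // index -> i64, materialized as a rank-0 tensor for StableHlo consumption
  %size_i64 = arith.index_cast %extent : index to i64
  %size_1d = tensor.from_elements %size_i64 : tensor<1xi64>
  %size = stablehlo.reshape %size_1d : (tensor<1xi64>) -> tensor<i64>
  // broadcast the scalar axis size to the (dynamic) indices shape
  %bcast = stablehlo.dynamic_broadcast_in_dim %size, %indices_shape, dims = [] : (tensor<i64>, tensor<2xindex>) -> tensor<?x2xi64>
  return %bcast : tensor<?x2xi64>
}
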
diff --git a/src/Conversion/ONNXToStableHlo/Tensor/Pad.cpp b/src/Conversion/ONNXToStableHlo/Tensor/Pad.cpp
index 8e3b7162c3..9d305e5ed2 100644
--- a/src/Conversion/ONNXToStableHlo/Tensor/Pad.cpp
+++ b/src/Conversion/ONNXToStableHlo/Tensor/Pad.cpp
@@ -52,8 +52,10 @@ struct ONNXPadOpLoweringToStablehlo : public ConversionPattern {
           rewriter.getZeroAttr(elemType)));
     } else {
       // constantValue might be 1D tensor, reshape it to scalar
-      constantValue = rewriter.create<stablehlo::ReshapeOp>(
-          loc, RankedTensorType::get({}, elemType), constantValue);
+      ShapedType constantType = constantValue.getType().cast<ShapedType>();
+      if (constantType.getRank() != 0)
+        constantValue = rewriter.create<stablehlo::ReshapeOp>(
+            loc, RankedTensorType::get({}, elemType), constantValue);
     }
     SmallVector<int64_t> edgePaddingLowVec(rank, 0);
     SmallVector<int64_t> edgePaddingHighVec(rank, 0);
@@ -91,7 +93,7 @@
 
 } // namespace
 
-void populateLoweringONNXPadOpToStablehloPattern(
+void populateLoweringONNXPadOpToStableHloPattern(
     RewritePatternSet &patterns, MLIRContext *ctx) {
   patterns.insert<ONNXPadOpLoweringToStablehlo>(ctx);
 }
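
The rank guard above matters because the ONNX pad value can arrive either as a rank-0 tensor or as a one-element 1-D tensor; unconditionally reshaping an already-scalar value emitted a redundant stablehlo.reshape. A hand-written illustration of the two cases, assuming an f32 pad value of 2.0 (not taken from the test suite):

func.func @pad_value_cases() -> (tensor<f32>, tensor<f32>) {
  // case 1: pad value is already rank 0, used directly, no reshape emitted
  %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
  // case 2: pad value arrives as a one-element 1-D tensor, reshaped to rank 0 first
  %cst_1d = stablehlo.constant dense<2.000000e+00> : tensor<1xf32>
  %scalar = stablehlo.reshape %cst_1d : (tensor<1xf32>) -> tensor<f32>
  return %cst, %scalar : tensor<f32>, tensor<f32>
}
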
diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir
index 2c2662d383..9694383137 100644
--- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir
+++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir
@@ -4,15 +4,16 @@ func.func @main_gather_elements(%arg0: tensor<3x2xf32>, %arg1: tensor<2x2xi64>)
   %0 = "onnx.GatherElements"(%arg0, %arg1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
   return %0 : tensor<2x2xf32>
 // CHECK: func.func @main_gather_elements([[PARAM_0_:%.+]]: tensor<3x2xf32>, [[PARAM_1_:%.+]]: tensor<2x2xi64>) -> tensor<2x2xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<3> : tensor<2x2xi64>
+// CHECK-DAG: [[CST_:%.+]] = arith.constant dense<[2, 2, 1]> : tensor<3xindex>
+// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<3> : tensor<i64>
 // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64>
-// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
-// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_0_]] : tensor<2x2xi64>
-// CHECK-NEXT: [[VAR_4_:%.+]] = stablehlo.select [[VAR_2_]], [[VAR_3_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64>
-// CHECK-NEXT: [[VAR_5_:%.+]] = stablehlo.reshape [[VAR_4_]] : (tensor<2x2xi64>) -> tensor<2x2x1xi64>
-// CHECK-DAG: [[VAR_6_:%.+]] = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi64>
-// CHECK-NEXT: [[VAR_7_:%.+]] = "stablehlo.broadcast_in_dim"([[VAR_6_]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<2x2x1xi64>
-// CHECK-NEXT: [[VAR_8_:%.+]] = "stablehlo.concatenate"([[VAR_5_]], [[VAR_7_]]) {dimension = 2 : i64} : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.broadcast_in_dim [[VAR_0_]], dims = [] : (tensor<i64>) -> tensor<2x2xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
+// CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_2_]] : tensor<2x2xi64>
+// CHECK-NEXT: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64>
+// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_reshape [[VAR_5_]], [[CST_]] : (tensor<2x2xi64>, tensor<3xindex>) -> tensor<2x2x1xi64>
+// CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.dynamic_iota [[CST_]], dim = 1 : (tensor<3xindex>) -> tensor<2x2x1xi64>
+// CHECK-NEXT: [[VAR_8_:%.+]] = stablehlo.concatenate [[VAR_6_]], [[VAR_7_]], dim = 2 : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64>
 // CHECK-NEXT: [[VAR_9_:%.+]] = "stablehlo.gather"([[PARAM_0_]], [[VAR_8_]]) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<3x2xf32>, tensor<2x2x2xi64>) -> tensor<2x2xf32>
 // CHECK-NEXT: return [[VAR_9_]] : tensor<2x2xf32>
 }
\ No newline at end of file
diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Pad.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Pad.mlir
index 7d7292c691..8b2e9702f5 100644
--- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Pad.mlir
+++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Pad.mlir
@@ -8,6 +8,6 @@ func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32>
   return %3 : tensor<1x3x7x7xf32>
 // CHECK-LABEL: func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32> {
 // CHECK-NEXT: %0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK-NEXT: %1 = "stablehlo.pad"(%arg0, %0) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<[0, 0, 1, 1]> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x7x7xf32>
+// CHECK-NEXT: %1 = stablehlo.pad %arg0, %0, low = [0, 0, 1, 1], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x7x7xf32>
 // CHECK-NEXT: return %1 : tensor<1x3x7x7xf32>
 }
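
As a semantics cross-check for the GatherElements test: with axis = 0, ONNX GatherElements computes output[i][j] = data[indices[i][j]][j], and negative indices count from the end of the axis, which is why the lowering broadcasts the axis size (3 here) and adds it to negative entries before gathering. A hand-written example under those semantics; the function name and constants are illustrative, not part of the test suite:

func.func @gather_elements_example(%data: tensor<3x2xf32>) -> tensor<2x2xf32> {
  // indices contain a negative entry; the compare/add/select sequence rewrites -1 to -1 + 3 = 2
  %indices = "onnx.Constant"() {value = dense<[[0, 2], [-1, 0]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
  // for data = [[1, 2], [3, 4], [5, 6]] the result is [[1, 6], [5, 2]]
  %out = "onnx.GatherElements"(%data, %indices) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
  return %out : tensor<2x2xf32>
}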