[stablehlo] fix conflict
Signed-off-by: Yan Xu <[email protected]>
Connor-XY committed Nov 1, 2023
1 parent 78113b6 commit d7c1f8d
Showing 5 changed files with 44 additions and 24 deletions.
5 changes: 4 additions & 1 deletion src/Conversion/ONNXToStableHlo/ConvertONNXToStableHlo.cpp
@@ -42,7 +42,9 @@ void populateONNXToStableHloConversionPattern(
populateLoweringONNXExpandOpToStableHloPattern(patterns, ctx);
populateLoweringONNXFlattenOpToStableHloPattern(patterns, ctx);
populateLoweringONNXGatherOpToStableHloPattern(patterns, ctx);
populateLoweringONNXGatherElementsOpToStableHloPattern(patterns, ctx);
populateLoweringONNXIdentityOpToStableHloPattern(patterns, ctx);
populateLoweringONNXPadOpToStableHloPattern(patterns, ctx);
populateLoweringONNXReshapeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXShapeOpToStableHloPattern(patterns, ctx);
populateLoweringONNXSliceOpToStableHloPattern(patterns, ctx);
@@ -92,7 +94,8 @@ void FrontendToStableHloLoweringPass::runOnOperation() {
// Added affine as some affine maps are generated by IndexExpression. It could
// be disabled and/or replaced by shape max/min.
target.addLegalDialect<stablehlo::StablehloDialect, func::FuncDialect,
arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect>();
arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect,
tensor::TensorDialect>();
// Needed to support unsigned int computations. To be removed if we use a
// scheme that does not rely on the UnrealizedConversionCastOp.
target.addLegalOp<::mlir::UnrealizedConversionCastOp>();
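For reference, each populateLowering...ToStableHloPattern call above registers a single conversion pattern into the shared RewritePatternSet. The newly listed GatherElements entry follows the same idiom as the Pad helper shown later in this commit; a minimal sketch (the pattern class itself lives in Tensor/GatherElements.cpp):

// Sketch of the registration idiom used by the helpers above; mirrors the
// Pad variant further down in this diff.
void populateLoweringONNXGatherElementsOpToStableHloPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXGatherElementsOpLoweringToStableHlo>(ctx);
}

tensor::TensorDialect is added to the legal dialects because the dynamic-shape path of the new GatherElements lowering emits tensor.from_elements, which the conversion target would otherwise reject.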
36 changes: 25 additions & 11 deletions src/Conversion/ONNXToStableHlo/Tensor/GatherElements.cpp
@@ -57,17 +57,31 @@ struct ONNXGatherElementsOpLoweringToStableHlo : public ConversionPattern {
Value zero = getShapedZero(loc, rewriter, indices);
Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, data);
Value indicesShape = rewriter.create<shape::ShapeOfOp>(loc, indices);
Value axisDimSize =
rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
axisDimSize =
rewriter.create<arith::IndexCastOp>(loc, indexElemType, axisDimSize);
axisDimSize = rewriter.create<tensor::FromElementsOp>(loc, axisDimSize);
axisDimSize = rewriter.create<stablehlo::ReshapeOp>(loc,
RankedTensorType::get(SmallVector<int64_t>{}, indexElemType),
axisDimSize);
Value broadcastedAxisDimSize =
rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, indicesType,
axisDimSize, indicesShape, rewriter.getI64TensorAttr({}));
Value broadcastedAxisDimSize, axisDimSize;
if (inputType.hasStaticShape()) {
axisDimSize = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getIntegerAttr(
indexElemType, inputType.getDimSize(axisLit)));
} else {
axisDimSize =
rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
axisDimSize =
rewriter.create<arith::IndexCastOp>(loc, indexElemType, axisDimSize);
axisDimSize = rewriter.create<tensor::FromElementsOp>(loc, axisDimSize);
axisDimSize = rewriter.create<stablehlo::ReshapeOp>(loc,
RankedTensorType::get(SmallVector<int64_t>{}, indexElemType),
axisDimSize);
}
if (indicesType.hasStaticShape()) {
broadcastedAxisDimSize = rewriter.create<stablehlo::BroadcastInDimOp>(
loc, indicesType, axisDimSize, rewriter.getI64TensorAttr({}));
} else {
broadcastedAxisDimSize =
rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, indicesType,
axisDimSize, indicesShape, rewriter.getI64TensorAttr({}));
}
Value isNegative = rewriter.create<stablehlo::CompareOp>(
loc, indices, zero, stablehlo::ComparisonDirection::LT);
Value positiveIndices = rewriter.create<stablehlo::AddOp>(
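The net effect of this hunk is a split on shape staticness: when the data type has a static shape, the axis extent is folded into a stablehlo.constant, and when the indices type is static the broadcast uses the plain broadcast_in_dim; only genuinely dynamic shapes still go through shape.get_extent, tensor.from_elements and dynamic_broadcast_in_dim. A condensed, illustrative helper that captures the same decision (getAxisDimSize and its parameters are hypothetical names, not part of the commit; the builder calls mirror the ones in the diff):

// Illustrative only: fold the axis extent into a constant when the input
// shape is static, otherwise query it at runtime as the diff above does.
static Value getAxisDimSize(OpBuilder &b, Location loc, Value data,
    RankedTensorType inputType, int64_t axisLit, Type indexElemType) {
  if (inputType.hasStaticShape())
    return b.create<stablehlo::ConstantOp>(
        loc, b.getIntegerAttr(indexElemType, inputType.getDimSize(axisLit)));
  Value inputShape = b.create<shape::ShapeOfOp>(loc, data);
  Value extent = b.create<shape::GetExtentOp>(loc, inputShape, axisLit);
  extent = b.create<arith::IndexCastOp>(loc, indexElemType, extent);
  Value scalar = b.create<tensor::FromElementsOp>(loc, extent);
  return b.create<stablehlo::ReshapeOp>(loc,
      RankedTensorType::get(SmallVector<int64_t>{}, indexElemType), scalar);
}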
8 changes: 5 additions & 3 deletions src/Conversion/ONNXToStableHlo/Tensor/Pad.cpp
@@ -52,8 +52,10 @@ struct ONNXPadOpLoweringToStablehlo : public ConversionPattern {
rewriter.getZeroAttr(elemType)));
} else {
// constantValue might be 1D tensor, reshape it to scalar
constantValue = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({}, elemType), constantValue);
ShapedType constantType = constantValue.getType().cast<ShapedType>();
if (constantType.getRank() != 0)
constantValue = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({}, elemType), constantValue);
}
SmallVector<int64_t> edgePaddingLowVec(rank, 0);
SmallVector<int64_t> edgePaddingHighVec(rank, 0);
@@ -91,7 +93,7 @@ struct ONNXPadOpLoweringToStablehlo : public ConversionPattern {

} // namespace

void populateLoweringONNXPadOpToStablehloPattern(
void populateLoweringONNXPadOpToStableHloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXPadOpLoweringToStablehlo>(ctx);
}
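The new guard matters because the ONNX pad value can arrive either as a one-element 1-D tensor or already as a scalar; stablehlo.pad expects a 0-d padding value, so the reshape is now emitted only when there is actually a rank to collapse. The helper is also renamed to the StableHlo spelling so it matches the call added in ConvertONNXToStableHlo.cpp above. A small illustrative fragment of the same check (padValue is a placeholder name; the calls mirror the diff):

// Illustrative: collapse a tensor<1xf32>-style pad value to tensor<f32>,
// but leave an already-scalar value untouched.
ShapedType padType = padValue.getType().cast<ShapedType>();
if (padType.getRank() != 0)
  padValue = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get({}, padType.getElementType()), padValue);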
17 changes: 9 additions & 8 deletions test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir
@@ -4,15 +4,16 @@ func.func @main_gather_elements(%arg0: tensor<3x2xf32>, %arg1: tensor<2x2xi64>)
%0 = "onnx.GatherElements"(%arg0, %arg1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK: func.func @main_gather_elements([[PARAM_0_:%.+]]: tensor<3x2xf32>, [[PARAM_1_:%.+]]: tensor<2x2xi64>) -> tensor<2x2xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<3> : tensor<2x2xi64>
// CHECK-DAG: [[CST_:%.+]] = arith.constant dense<[2, 2, 1]> : tensor<3xindex>
// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<3> : tensor<i64>
// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_0_]] : tensor<2x2xi64>
// CHECK-NEXT: [[VAR_4_:%.+]] = stablehlo.select [[VAR_2_]], [[VAR_3_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64>
// CHECK-NEXT: [[VAR_5_:%.+]] = stablehlo.reshape [[VAR_4_]] : (tensor<2x2xi64>) -> tensor<2x2x1xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi64>
// CHECK-NEXT: [[VAR_7_:%.+]] = "stablehlo.broadcast_in_dim"([[VAR_6_]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<2x2x1xi64>
// CHECK-NEXT: [[VAR_8_:%.+]] = "stablehlo.concatenate"([[VAR_5_]], [[VAR_7_]]) {dimension = 2 : i64} : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.broadcast_in_dim [[VAR_0_]], dims = [] : (tensor<i64>) -> tensor<2x2xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
// CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_2_]] : tensor<2x2xi64>
// CHECK-NEXT: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_reshape [[VAR_5_]], [[CST_]] : (tensor<2x2xi64>, tensor<3xindex>) -> tensor<2x2x1xi64>
// CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.dynamic_iota [[CST_]], dim = 1 : (tensor<3xindex>) -> tensor<2x2x1xi64>
// CHECK-NEXT: [[VAR_8_:%.+]] = stablehlo.concatenate [[VAR_6_]], [[VAR_7_]], dim = 2 : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64>
// CHECK-NEXT: [[VAR_9_:%.+]] = "stablehlo.gather"([[PARAM_0_]], [[VAR_8_]]) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<3x2xf32>, tensor<2x2x2xi64>) -> tensor<2x2xf32>
// CHECK-NEXT: return [[VAR_9_]] : tensor<2x2xf32>
}
2 changes: 1 addition & 1 deletion test/mlir/conversion/onnx_to_stablehlo/Tensor/Pad.mlir
@@ -8,6 +8,6 @@ func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32>
return %3 : tensor<1x3x7x7xf32>
// CHECK-LABEL: func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32> {
// CHECK-NEXT: %0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK-NEXT: %1 = "stablehlo.pad"(%arg0, %0) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<[0, 0, 1, 1]> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x7x7xf32>
// CHECK-NEXT: %1 = stablehlo.pad %arg0, %0, low = [0, 0, 1, 1], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x7x7xf32>
// CHECK-NEXT: return %1 : tensor<1x3x7x7xf32>
}
