[stablehlo] fix ConvTranspose
Signed-off-by: yan.xu0210 <[email protected]>
Connor-XY committed Mar 8, 2024
1 parent 9202207 commit a719175
Showing 4 changed files with 13 additions and 63 deletions.
54 changes: 8 additions & 46 deletions src/Conversion/ONNXToStablehlo/NN/ConvTranspose.cpp
@@ -78,8 +78,6 @@ struct ONNXConvTransposeOpLoweringToStablehlo : public ConversionPattern {
     llvm::SmallVector<int64_t, 2> strides = shapeHelper.strides;
     llvm::SmallVector<int64_t, 2> dilations = shapeHelper.dilations;
     llvm::SmallVector<int64_t, 2> outputPadding = shapeHelper.outputPadding;
-    bool needOutputPadding = std::any_of(outputPadding.begin(),
-        outputPadding.end(), [](int64_t i) { return i != 0; });

     Value inputOperand = operandAdaptor.getX();
     Value filterOperand = operandAdaptor.getW();
@@ -96,23 +94,7 @@ struct ONNXConvTransposeOpLoweringToStablehlo : public ConversionPattern {
     int64_t rank = inputType.getRank();
     int64_t kernelSize = kernelShape.size();

-    Type outputType = *op->result_type_begin();
-    Type convOutputType;
-    if (!needOutputPadding)
-      convOutputType = outputType;
-    else {
-      // use the shape inference result of shapeHelper
-      llvm::SmallVector<IndexExpr, 2> dimsNoOutputPadding =
-          shapeHelper.dimsNoOutputPadding;
-      SmallVector<int64_t> convOutputShape;
-      for (int i = 0; i < rank; ++i) {
-        if (dimsNoOutputPadding[i].isLiteral())
-          convOutputShape.emplace_back(dimsNoOutputPadding[i].getLiteral());
-        else
-          convOutputShape.emplace_back(ShapedType::kDynamic);
-      }
-      convOutputType = RankedTensorType::get(convOutputShape, elemType);
-    }
+    Type convOutputType = *op->result_type_begin();

     SmallVector<int64_t> spatialDimensions;
     for (int64_t i = spatialOffset; i < rank; i++) {
@@ -136,7 +118,7 @@ struct ONNXConvTransposeOpLoweringToStablehlo : public ConversionPattern {
           pads[i].getLiteral());
       flattenPaddings.push_back(
           dilations[i] * (kernelShape[i].getLiteral() - 1) -
-          pads[i + spatialRank].getLiteral());
+          pads[i + spatialRank].getLiteral() + outputPadding[i]);
     }

     stablehlo::ConvDimensionNumbersAttr dimension_numbers =
@@ -175,37 +157,17 @@ struct ONNXConvTransposeOpLoweringToStablehlo : public ConversionPattern {
             dilations),
         nullptr, dimension_numbers, groupNum, 1, nullptr);

-    Value padResult;
-    if (!needOutputPadding) {
-      padResult = convResult;
-    } else {
-      SmallVector<int64_t> edgePaddingLowVec(rank, 0);
-      SmallVector<int64_t> edgePaddingHighVec(rank, 0);
-      SmallVector<int64_t> interiorPaddingVec(rank, 0);
-      std::copy(outputPadding.begin(), outputPadding.end(),
-          edgePaddingHighVec.begin() + 2);
-      Value zeroPaddingValue = rewriter.create<stablehlo::ConstantOp>(
-          loc, DenseElementsAttr::get(mlir::RankedTensorType::get({}, elemType),
-                   rewriter.getZeroAttr(elemType)));
-      mlir::DenseI64ArrayAttr edgePaddingLow =
-          rewriter.getDenseI64ArrayAttr(edgePaddingLowVec);
-      mlir::DenseI64ArrayAttr edgePaddingHigh =
-          rewriter.getDenseI64ArrayAttr(edgePaddingHighVec);
-      mlir::DenseI64ArrayAttr interiorPadding =
-          rewriter.getDenseI64ArrayAttr(interiorPaddingVec);
-      padResult = rewriter.create<stablehlo::PadOp>(loc, outputType, convResult,
-          zeroPaddingValue, edgePaddingLow, edgePaddingHigh, interiorPadding);
-    }
-
     Value addBiasResult;
     if (!hasBias) {
-      addBiasResult = padResult;
+      addBiasResult = convResult;
     } else {
       Value finalB;
-      Value resultShape = rewriter.create<shape::ShapeOfOp>(loc, padResult);
+      Value resultShape = rewriter.create<shape::ShapeOfOp>(loc, convResult);
       finalB = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc,
-          outputType, biasOperand, resultShape, rewriter.getI64TensorAttr({1}));
-      addBiasResult = rewriter.create<stablehlo::AddOp>(loc, padResult, finalB);
+          convOutputType, biasOperand, resultShape,
+          rewriter.getI64TensorAttr({1}));
+      addBiasResult =
+          rewriter.create<stablehlo::AddOp>(loc, convResult, finalB);
     }

     rewriter.replaceOp(op, addBiasResult);
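
The net effect of this file's change: instead of emitting the convolution at a pre-padding shape and appending a stablehlo.pad, the lowering folds output_padding into the high-edge padding of the stablehlo.convolution itself, so the convolution already produces the final output shape. The padding rule can be sketched in isolation roughly as follows; this is a minimal standalone illustration, not code from the pass, and the names transposedConvPadding, padsBegin, and padsEnd are made up for this example.

// Standalone sketch of the per-spatial-dimension edge padding that the
// updated lowering hands to stablehlo.convolution (window stride 1,
// lhs_dilation = ONNX strides). Names here are illustrative only.
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> transposedConvPadding(
    const std::vector<int64_t> &kernel, const std::vector<int64_t> &dilations,
    const std::vector<int64_t> &padsBegin, const std::vector<int64_t> &padsEnd,
    const std::vector<int64_t> &outputPadding) {
  std::vector<std::pair<int64_t, int64_t>> result;
  for (size_t i = 0; i < kernel.size(); ++i) {
    int64_t edge = dilations[i] * (kernel[i] - 1);
    // output_padding lands on the high edge, which is what makes the trailing
    // stablehlo.pad (and the needOutputPadding bookkeeping) unnecessary.
    result.emplace_back(edge - padsBegin[i], edge - padsEnd[i] + outputPadding[i]);
  }
  return result;
}
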
9 changes: 0 additions & 9 deletions src/Dialect/ONNX/ONNXOps/NN/Conv.cpp
@@ -483,15 +483,6 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() {

   // Save the final result.
   setOutputDims(outputDims);
-
-  dimsNoOutputPadding.emplace_back(outputDims[0]);
-  dimsNoOutputPadding.emplace_back(outputDims[1]);
-  for (int i = 0; i < spatialRank; ++i) {
-    LiteralIndexExpr outPad(outputPadding[i]);
-    IndexExpr dimNoOutPad =
-        IndexExpr::max(zeroIE, outputDims[i + spatialOffset] - outPad);
-    dimsNoOutputPadding.emplace_back(dimNoOutPad);
-  }
   return success();
 }

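
dimsNoOutputPadding existed only to give the lowering the shape of the convolution before output_padding was appended; with the padding now folded into the convolution window, the shape helper only needs the final output dims. For reference, under the standard ONNX ConvTranspose rule (when output_shape is not explicitly given) those dims are out[i] = strides[i] * (in[i] - 1) + output_padding[i] + ((kernel[i] - 1) * dilations[i] + 1) - pad_begin[i] - pad_end[i], so output_padding is already accounted for in setOutputDims.
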
3 changes: 1 addition & 2 deletions src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
@@ -468,7 +468,7 @@ struct ONNXConvTransposeOpShapeHelper : public ONNXOpShapeHelper {
   ONNXConvTransposeOpShapeHelper(mlir::Operation *op, mlir::ValueRange operands,
       IndexExprBuilder *ieBuilder = nullptr, IndexExprScope *scope = nullptr)
       : ONNXOpShapeHelper(op, operands, ieBuilder, scope), kernelShape(),
-        pads(), strides(), dilations(), outputPadding(), dimsNoOutputPadding() {
+        pads(), strides(), dilations(), outputPadding() {
   }
   virtual ~ONNXConvTransposeOpShapeHelper() {}
   mlir::LogicalResult computeShape() final;
@@ -478,7 +478,6 @@ struct ONNXConvTransposeOpShapeHelper : public ONNXOpShapeHelper {
   llvm::SmallVector<int64_t, 2> strides;
   llvm::SmallVector<int64_t, 2> dilations;
   llvm::SmallVector<int64_t, 2> outputPadding;
-  llvm::SmallVector<IndexExpr, 2> dimsNoOutputPadding;
 };

 //===----------------------------------------------------------------------===//
10 changes: 4 additions & 6 deletions test/mlir/conversion/onnx_to_stablehlo/NN/ConvTranspose.mlir
@@ -75,12 +75,10 @@ func.func @test_attributes_1(%arg0 : tensor<?x1x3x3xf32>, %arg1 : tensor<1x2x3x3

 // CHECK-LABEL: func.func @test_attributes_1
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) -> tensor<?x2x10x8xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.reverse [[PARAM_1_]], dims = [2, 3] : tensor<1x2x3x3xf32>
-// CHECK: [[VAR_2_:%.+]] = stablehlo.transpose [[VAR_1_]], dims = [1, 0, 2, 3] : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32>
-// CHECK: [[VAR_3_:%.+]] = stablehlo.convolution([[PARAM_0_]], [[VAR_2_]]) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = {{.}}[2, 2], [2, 2]{{.}}, lhs_dilate = [3, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<?x1x3x3xf32>, tensor<2x1x3x3xf32>) -> tensor<?x2x9x7xf32>
-// CHECK: [[VAR_4_:%.+]] = stablehlo.pad [[VAR_3_]], [[VAR_0_]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<?x2x9x7xf32>, tensor<f32>) -> tensor<?x2x10x8xf32>
-// CHECK: return [[VAR_4_]] : tensor<?x2x10x8xf32>
+// CHECK: [[VAR_0_:%.+]] = stablehlo.reverse [[PARAM_1_]], dims = [2, 3] : tensor<1x2x3x3xf32>
+// CHECK: [[VAR_1_:%.+]] = stablehlo.transpose [[VAR_0_]], dims = [1, 0, 2, 3] : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32>
+// CHECK: [[VAR_2_:%.+]] = stablehlo.convolution([[PARAM_0_]], [[VAR_1_]]) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = {{.}}[2, 3], [2, 3]{{.}}, lhs_dilate = [3, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<?x1x3x3xf32>, tensor<2x1x3x3xf32>) -> tensor<?x2x10x8xf32>
+// CHECK: return [[VAR_2_]] : tensor<?x2x10x8xf32>
 // CHECK: }

 // -----
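
As a quick consistency check on the new CHECK lines, the emitted stablehlo.convolution already yields the expected tensor<?x2x10x8xf32> spatial sizes with no trailing pad op. A throwaway calculation, assuming the usual StableHLO output-size rule for a window stride of 1; the helper name convOutSize is made up for this example.

#include <cstdint>
#include <cstdio>

// Spatial output size of stablehlo.convolution for window stride 1:
// ((in - 1) * lhs_dilate + 1) + padLow + padHigh - ((k - 1) * rhs_dilate + 1) + 1.
static int64_t convOutSize(int64_t in, int64_t lhsDilate, int64_t k,
    int64_t rhsDilate, int64_t padLow, int64_t padHigh) {
  return (in - 1) * lhsDilate + 1 + padLow + padHigh - ((k - 1) * rhsDilate + 1) + 1;
}

int main() {
  // Height: in=3, lhs_dilate=3, k=3, pad [2, 3] -> 10.
  // Width:  in=3, lhs_dilate=2, k=3, pad [2, 3] -> 8.
  std::printf("%lld x %lld\n", (long long)convOutSize(3, 3, 3, 1, 2, 3),
      (long long)convOutSize(3, 2, 3, 1, 2, 3));
  return 0;
}
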
