Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into concat-refine
Browse files Browse the repository at this point in the history
  • Loading branch information
chentong319 committed Oct 11, 2022
2 parents a74cb69 + 447dc54 commit e683d53
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 151 deletions.
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ add_onnx_mlir_library(OMONNXOps
ShapeInference/GatherND.cpp
ShapeInference/Gemm.cpp
ShapeInference/LRN.cpp
ShapeInference/MatMul.cpp
ShapeInference/GenericMatMul.cpp
ShapeInference/MaxPool.cpp
ShapeInference/ONNXShapeHelper.cpp
ShapeInference/OneHot.cpp
Expand Down
138 changes: 16 additions & 122 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,130 +1405,29 @@ LogicalResult ONNXMatMulOp::inferShapes(

LogicalResult ONNXQLinearMatMulOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
// Cannot infer shape if no shape exists.
if (!a().getType().isa<RankedTensorType>() ||
!b().getType().isa<RankedTensorType>())
return success();

auto lhsTy = a().getType().cast<RankedTensorType>();
auto rhsTy = b().getType().cast<RankedTensorType>();
auto elementType = getResult().getType().cast<ShapedType>().getElementType();
return shapeHelperInferShapes<ONNXQLinearMatMulOpShapeHelper,
ONNXQLinearMatMulOp, ONNXQLinearMatMulOpAdaptor>(*this, elementType);
}

SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape();
auto rhsShape = rhsTy.getShape();

if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed.
return emitError("Multiplication by scalar arguments not allowed");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if
// all sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
return emitError("Attempt to multiply incompatible matrices");
dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
// If the first argument is 1-D, it is promoted to a matrix by prepending
// a 1 to its dimensions. After matrix multiplication the prepended 1 is
// removed.
//
// N MATMUL (s1 x s2 x... x sK x N x P)
// =>
// (s1 x s2 x... x sK x P)

// Check legality of matrix multiplication.
unsigned rhsRank = rhsShape.size();
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[0] != rhsShape[rhsRank - 2])
return emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
// If the second argument is 1-D, it is promoted to a matrix by appending
// a 1 to its dimensions. After matrix multiplication the appended 1 is
// removed.
//
// (s1 x s2 x... x sK x M x N) MATMUL N
// =>
// (s1 x s2 x... x sK x M)

// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
return emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(lhsShape[lhsRank - 2]);
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
// =>
// (s1 x s2 x... x sK x M x P)

// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
return emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]);
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
// =>
// (s1 x s2 x... x sK x M x P)

// Check legality of matrix multiplication.
unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2])
return emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
// =>
// (u1 x u2 x... x uK x M x P)

// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
return emitError("Attempt to multiply incompatible matrices");
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
lhsBcastShape.emplace_back(lhsShape[i]);
SmallVector<int64_t, 2> rhsBcastShape;
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
return emitError("Broadcasted dimensions are incompatible");
dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else {
// This case covers all remaining combinations of 1 and 2-D matrices.
int64_t lhsDim = lhsShape[0];
int64_t rhsDim = rhsShape[0];
if (lhsShape.size() > 1) {
lhsDim = lhsShape[1];
dims.emplace_back(lhsShape[0]);
}
//===----------------------------------------------------------------------===//
// MatMulIntegerOp
//===----------------------------------------------------------------------===//

// Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
return emitError("Attempt to multiply incompatible matrices");
if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]);
}
LogicalResult ONNXMatMulIntegerOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>())
return success();

getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
return success();
auto elementType = getResult().getType().cast<ShapedType>().getElementType();
return shapeHelperInferShapes<ONNXMatMulIntegerOpShapeHelper,
ONNXMatMulIntegerOp, ONNXMatMulIntegerOpAdaptor>(*this, elementType);
}

// GemmOp
Expand Down Expand Up @@ -4529,11 +4428,6 @@ LogicalResult ONNXLpPoolOp::inferShapes(
return emitError(NOT_IMPLEMENTED_MESSAGE);
}

// Shape inference for MatMulInteger is not implemented on this path; emit an
// error so the caller knows the result shape cannot be derived (rather than
// silently leaving the result unranked).
LogicalResult ONNXMatMulIntegerOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
return emitError(NOT_IMPLEMENTED_MESSAGE);
}

LogicalResult ONNXMaxPoolOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
return emitError(NOT_IMPLEMENTED_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,49 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===------------ MatMul.cpp - Shape Inference for MatMul Op --------------===//
//===--------- GenericMatMul.cpp - Shape Inference for matmul Ops ---------===//
//
// This file implements shape inference for the ONNX MatMul Operator.
// This file implements shape inference for the ONNX MatMul, QLinearMatMul, and
// MatMulInteger Operators.
//
//===----------------------------------------------------------------------===//

#include <tuple>
#include <utility>

#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"

using namespace mlir;

namespace onnx_mlir {

ONNXMatMulOpShapeHelper::ONNXMatMulOpShapeHelper(
ONNXMatMulOp *newOp, IndexExprScope *inScope)
: ONNXOpShapeHelper<ONNXMatMulOp>(
newOp, newOp->getOperation()->getNumResults(), inScope),
aDims(), bDims(), aPadDims(), bPadDims() {}
// Retrieve the two matrix operands of a matmul-like op. The primary template
// covers adaptors that expose the inputs as A()/B() (MatMul, MatMulInteger);
// ops with differently-named accessors provide a specialization.
template <typename OpAdaptor>
std::pair<Value, Value> matMulInputs(OpAdaptor &operandAdaptor) {
  return std::make_pair(operandAdaptor.A(), operandAdaptor.B());
}

ONNXMatMulOpShapeHelper::ONNXMatMulOpShapeHelper(ONNXMatMulOp *newOp,
OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal, IndexExprScope *inScope)
: ONNXOpShapeHelper<ONNXMatMulOp>(newOp,
newOp->getOperation()->getNumResults(), rewriter, fGetDenseVal,
fLoadVal, inScope),
aDims(), bDims(), aPadDims(), bPadDims() {}
// QLinearMatMul names its matrix inputs a()/b() (lowercase), so specialize
// the accessor rather than using the generic A()/B() template above.
template <>
std::pair<Value, Value> matMulInputs(
    ONNXQLinearMatMulOpAdaptor &operandAdaptor) {
  return std::make_pair(operandAdaptor.a(), operandAdaptor.b());
}

template <typename OP_TYPE>
LogicalResult ONNXGenericMatMulOpShapeHelper<OP_TYPE>::computeShape(
typename OP_TYPE::Adaptor operandAdaptor) {

LogicalResult ONNXMatMulOpShapeHelper::computeShape(
ONNXMatMulOpAdaptor operandAdaptor) {
// Shape inference indicated by passing a null rewriter pointer.
// Output dims of result.
DimsExpr outputDims;

// Get info.
Value A = operandAdaptor.A();
Value B = operandAdaptor.B();
Value A;
Value B;
std::tie(A, B) = matMulInputs(operandAdaptor);
MemRefBoundsIndexCapture ABounds(A);
MemRefBoundsIndexCapture BBounds(B);

Expand Down Expand Up @@ -93,7 +101,7 @@ LogicalResult ONNXMatMulOpShapeHelper::computeShape(
} else if (aDims[i].isLiteral() && bDims[i].isLiteral()) {
// No broadcast, both literals, make sure they have the same value.
if (aDims[i].getLiteral() != bDims[i].getLiteral())
return op->emitError("Incompatible size detected");
return this->op->emitError("Incompatible size detected");
outputDims.emplace_back(aDims[i]);
} else if (aDims[i].isLiteral()) {
// A dim is a literal; use it here for output and b, since b
Expand All @@ -119,7 +127,7 @@ LogicalResult ONNXMatMulOpShapeHelper::computeShape(
// And test the K dimensions.
if (aDims[aK].isLiteral() && bDims[bK].isLiteral()) {
if (aDims[aK].getLiteral() != bDims[bK].getLiteral())
return op->emitError("reduction dimension must be the same");
return this->op->emitError("reduction dimension must be the same");
} else if (aDims[aK].isLiteral()) {
// Save aK dims into bK dims, in case bK dims was runtime
bDims[bK] = aDims[aK];
Expand All @@ -137,8 +145,12 @@ LogicalResult ONNXMatMulOpShapeHelper::computeShape(
assert(outputDims.empty() && "1-D x 1-D results in scalar");
}
// Save the final result.
setOutputDims(outputDims);
this->setOutputDims(outputDims);
return success();
}

template struct ONNXGenericMatMulOpShapeHelper<ONNXMatMulOp>;
template struct ONNXGenericMatMulOpShapeHelper<ONNXMatMulIntegerOp>;
template struct ONNXGenericMatMulOpShapeHelper<ONNXQLinearMatMulOp>;

} // namespace onnx_mlir
2 changes: 2 additions & 0 deletions src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ template struct ONNXOpShapeHelper<ONNXGatherOp>;
template struct ONNXOpShapeHelper<ONNXGatherElementsOp>;
template struct ONNXOpShapeHelper<ONNXGatherNDOp>;
template struct ONNXOpShapeHelper<ONNXGemmOp>;
template struct ONNXOpShapeHelper<ONNXQLinearMatMulOp>;
template struct ONNXOpShapeHelper<ONNXMatMulIntegerOp>;
template struct ONNXOpShapeHelper<ONNXMatMulOp>;
template struct ONNXOpShapeHelper<ONNXMaxPoolSingleOutOp>;
template struct ONNXOpShapeHelper<ONNXOneHotOp>;
Expand Down
28 changes: 21 additions & 7 deletions src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,22 +314,36 @@ struct ONNXGemmOpShapeHelper : public ONNXOpShapeHelper<mlir::ONNXGemmOp> {
int cRank; // Dim of the original C (not padding dims by 1).
};

// Shape for MatMulOp.
struct ONNXMatMulOpShapeHelper : public ONNXOpShapeHelper<mlir::ONNXMatMulOp> {
ONNXMatMulOpShapeHelper(
mlir::ONNXMatMulOp *newOp, IndexExprScope *inScope = nullptr);
ONNXMatMulOpShapeHelper(mlir::ONNXMatMulOp *newOp, mlir::OpBuilder *rewriter,
template <typename OP_TYPE>
struct ONNXGenericMatMulOpShapeHelper : public ONNXOpShapeHelper<OP_TYPE> {
ONNXGenericMatMulOpShapeHelper(
OP_TYPE *newOp, IndexExprScope *inScope = nullptr)
: ONNXOpShapeHelper<OP_TYPE>(
newOp, newOp->getOperation()->getNumResults(), inScope),
aDims(), bDims(), aPadDims(), bPadDims() {}
ONNXGenericMatMulOpShapeHelper(OP_TYPE *newOp, mlir::OpBuilder *rewriter,
ArrayValueIndexCapture::GetDenseVal fGetDenseVal,
ArrayValueIndexCapture::LoadVal fLoadVal,
IndexExprScope *inScope = nullptr);
mlir::LogicalResult computeShape(mlir::ONNXMatMulOpAdaptor operandAdaptor);
IndexExprScope *inScope = nullptr)
: ONNXOpShapeHelper<OP_TYPE>(newOp,
newOp->getOperation()->getNumResults(), rewriter, fGetDenseVal,
fLoadVal, inScope),
aDims(), bDims(), aPadDims(), bPadDims() {}
mlir::LogicalResult computeShape(typename OP_TYPE::Adaptor operandAdaptor);
// Additional data for MatMulOp: output = a & b.
llvm::SmallVector<IndexExpr, 4> aDims,
bDims; // Dim after applying padding.
llvm::BitVector aPadDims,
bPadDims; // When true, that dim was padded.
};

// Convenience aliases: one ONNXGenericMatMulOpShapeHelper instantiation per
// matmul-like op, so existing call sites keep their op-specific helper names.
using ONNXMatMulOpShapeHelper =
ONNXGenericMatMulOpShapeHelper<mlir::ONNXMatMulOp>;
using ONNXMatMulIntegerOpShapeHelper =
ONNXGenericMatMulOpShapeHelper<mlir::ONNXMatMulIntegerOp>;
using ONNXQLinearMatMulOpShapeHelper =
ONNXGenericMatMulOpShapeHelper<mlir::ONNXQLinearMatMulOp>;

// Shape for PadOp.
struct ONNXPadOpShapeHelper : public ONNXOpShapeHelper<mlir::ONNXPadOp> {
ONNXPadOpShapeHelper(
Expand Down
27 changes: 27 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,33 @@ func.func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -

// -----

/// QLinearMatMul

// 3-D x 3-D case: batch dim 2 carried through, (2x4)x(4x3) -> 2x3, so the
// inferred result type is 2x2x3; element type ui8 is taken from the result.
func.func @test_qlinearmatmul_1(%arg0: tensor<2x2x4xui8>, %arg1: tensor<1xf32>, %arg2: tensor<1xui8>, %arg3: tensor<2x4x3xui8>, %arg4: tensor<1xf32>, %arg5: tensor<1xui8>, %arg6: tensor<1xf32>, %arg7: tensor<1xui8>) -> tensor<*xui8> {
%0 = "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2x2x4xui8>, tensor<1xf32>, tensor<1xui8>, tensor<2x4x3xui8>, tensor<1xf32>, tensor<1xui8>, tensor<1xf32>, tensor<1xui8>) -> tensor<*xui8>
"func.return"(%0) : (tensor<*xui8>) -> ()


// CHECK-LABEL: test_qlinearmatmul_1
// CHECK: [[RES1:%.+]] = "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2x2x4xui8>, tensor<1xf32>, tensor<1xui8>, tensor<2x4x3xui8>, tensor<1xf32>, tensor<1xui8>, tensor<1xf32>, tensor<1xui8>) -> tensor<2x2x3xui8>
// CHECK: return [[RES1]] : tensor<2x2x3xui8>
}

// -----

/// MatMulInteger

// 2-D case: (4x3)x(3x2) -> 4x2; the CHECK line pins the i32 result element
// type that the op's result carries (ui8 x ui8 accumulated into i32).
func.func @test_matmulinteger_1(%arg0: tensor<4x3xui8>, %arg1: tensor<3x2xui8>, %arg2: tensor<1xui8>, %arg3: tensor<1xui8>) -> tensor<*xi32> {
%0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x3xui8>, tensor<3x2xui8>, tensor<1xui8>, tensor<1xui8>) -> tensor<*xi32>
return %0 : tensor<*xi32>

// CHECK-LABEL: test_matmulinteger_1
// CHECK: [[RES1:%.+]] = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x3xui8>, tensor<3x2xui8>, tensor<1xui8>, tensor<1xui8>) -> tensor<4x2xi32>
// CHECK: return [[RES1]] : tensor<4x2xi32>
}

// -----

//===----------------------------------------------------------------------===//
/// Test shape inference for Conv (first with no bias) operation and all its attributes.
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit e683d53

Please sign in to comment.