From 447dc54140fc5447df1e1c54b157ad7fd85b0b0b Mon Sep 17 00:00:00 2001 From: Philip Lassen Date: Tue, 11 Oct 2022 03:45:16 -0700 Subject: [PATCH] add MatMulInt shape inference (#1761) * add MatMulInt shape inference Signed-off-by: Philip Lassen * An attempt Signed-off-by: Philip Lassen * Fix matmul Signed-off-by: Philip Lassen * Use macros in general matmul shape inference implementation Signed-off-by: Philip Lassen * Add matmul tests Signed-off-by: Philip Lassen * Replace subclassing with typedefs and removed macros Signed-off-by: Philip Lassen * Fix msvc warnings Signed-off-by: Philip Lassen * Remove static assert Signed-off-by: Philip Lassen Signed-off-by: Philip Lassen Signed-off-by: Philip Lassen --- src/Dialect/ONNX/CMakeLists.txt | 2 +- src/Dialect/ONNX/ONNXOps.cpp | 138 ++---------------- .../{MatMul.cpp => GenericMatMul.cpp} | 54 ++++--- .../ONNX/ShapeInference/ONNXShapeHelper.cpp | 2 + .../ONNX/ShapeInference/ONNXShapeHelper.hpp | 28 +++- test/mlir/onnx/onnx_shape_inference.mlir | 27 ++++ 6 files changed, 100 insertions(+), 151 deletions(-) rename src/Dialect/ONNX/ShapeInference/{MatMul.cpp => GenericMatMul.cpp} (79%) diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 8b6eaf4b8b..dbb26bbbc4 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -35,7 +35,7 @@ add_onnx_mlir_library(OMONNXOps ShapeInference/GatherND.cpp ShapeInference/Gemm.cpp ShapeInference/LRN.cpp - ShapeInference/MatMul.cpp + ShapeInference/GenericMatMul.cpp ShapeInference/MaxPool.cpp ShapeInference/ONNXShapeHelper.cpp ShapeInference/OneHot.cpp diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 6a662fce8e..561f08909e 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -1405,130 +1405,29 @@ LogicalResult ONNXMatMulOp::inferShapes( LogicalResult ONNXQLinearMatMulOp::inferShapes( std::function doShapeInference) { - // Cannot infer shape if no shape exists. if (!a().getType().isa() || !b().getType().isa()) return success(); - auto lhsTy = a().getType().cast(); - auto rhsTy = b().getType().cast(); + auto elementType = getResult().getType().cast().getElementType(); + return shapeHelperInferShapes(*this, elementType); +} - SmallVector dims; - auto lhsShape = lhsTy.getShape(); - auto rhsShape = rhsTy.getShape(); - - if (lhsShape.size() < 1 && rhsShape.size() < 1) { - // Multiplication by scalars is not allowed. - return emitError("Multiplication by scalar arguments not allowed"); - } else if (lhsShape.size() == 1 && rhsShape.size() == 1) { - // Special case when both arrays are 1-dimensional and according to - // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes - // need to be removed after the multiplication but cannot be removed if - // all sizes are 1. - if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0]) - return emitError("Attempt to multiply incompatible matrices"); - dims.emplace_back(1); - } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) { - // If the first argument is 1-D, it is promoted to a matrix by prepending - // a 1 to its dimensions. After matrix multiplication the prepended 1 is - // removed. - // - // N MATMUL (s1 x s2 x... x sK x N x P) - // => - // (s1 x s2 x... x sK x P) - - // Check legality of matrix multiplication. - unsigned rhsRank = rhsShape.size(); - if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 && - lhsShape[0] != rhsShape[rhsRank - 2]) - return emitError("Attempt to multiply incompatible matrices"); - for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) - dims.emplace_back(rhsShape[i]); - dims.emplace_back(rhsShape[rhsRank - 1]); - } else if (lhsShape.size() >= 2 && rhsShape.size() == 1) { - // If the second argument is 1-D, it is promoted to a matrix by appending - // a 1 to its dimensions. After matrix multiplication the appended 1 is - // removed. - // - // (s1 x s2 x... x sK x M x N) MATMUL N - // => - // (s1 x s2 x... x sK x M) - - // Check legality of matrix multiplication. - unsigned lhsRank = lhsShape.size(); - if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && - lhsShape[lhsRank - 1] != rhsShape[0]) - return emitError("Attempt to multiply incompatible matrices"); - for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i) - dims.emplace_back(lhsShape[i]); - dims.emplace_back(lhsShape[lhsRank - 2]); - } else if (lhsShape.size() > 2 && rhsShape.size() == 2) { - // (s1 x s2 x... x sK x M x N) MATMUL (N x P) - // => - // (s1 x s2 x... x sK x M x P) - - // Check legality of matrix multiplication. - unsigned lhsRank = lhsShape.size(); - if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && - lhsShape[lhsRank - 1] != rhsShape[0]) - return emitError("Attempt to multiply incompatible matrices"); - for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i) - dims.emplace_back(lhsShape[i]); - dims.emplace_back(rhsShape[1]); - } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { - // (M x N) MATMUL (s1 x s2 x... x sK x N x P) - // => - // (s1 x s2 x... x sK x M x P) - - // Check legality of matrix multiplication. - unsigned rhsRank = rhsShape.size(); - if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && - lhsShape[1] != rhsShape[rhsRank - 2]) - return emitError("Attempt to multiply incompatible matrices"); - for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) - dims.emplace_back(rhsShape[i]); - dims.emplace_back(lhsShape[0]); - dims.emplace_back(rhsShape[rhsRank - 1]); - } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { - // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P) - // => - // (u1 x u2 x... x uK x M x P) - - // Check legality of matrix multiplication. - unsigned lhsRank = lhsShape.size(); - unsigned rhsRank = rhsShape.size(); - if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && - lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) - return emitError("Attempt to multiply incompatible matrices"); - // Check and perform broadcasting for the shapes. - SmallVector lhsBcastShape; - for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i) - lhsBcastShape.emplace_back(lhsShape[i]); - SmallVector rhsBcastShape; - for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) - rhsBcastShape.emplace_back(rhsShape[i]); - if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) - return emitError("Broadcasted dimensions are incompatible"); - dims.emplace_back(lhsShape[lhsRank - 2]); - dims.emplace_back(rhsShape[rhsRank - 1]); - } else { - // This case covers all remaining combinations of 1 and 2-D matrices. - int64_t lhsDim = lhsShape[0]; - int64_t rhsDim = rhsShape[0]; - if (lhsShape.size() > 1) { - lhsDim = lhsShape[1]; - dims.emplace_back(lhsShape[0]); - } +//===----------------------------------------------------------------------===// +// MatMulIntegerOp +//===----------------------------------------------------------------------===// - // Check legality of matrix multiplication. - if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim) - return emitError("Attempt to multiply incompatible matrices"); - if (rhsShape.size() > 1) - dims.emplace_back(rhsShape[1]); - } +LogicalResult ONNXMatMulIntegerOp::inferShapes( + std::function doShapeInference) { + // Cannot infer shape if no shape exists. + if (!A().getType().isa() || + !B().getType().isa()) + return success(); - getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); - return success(); + auto elementType = getResult().getType().cast().getElementType(); + return shapeHelperInferShapes(*this, elementType); } // GemmOp @@ -4529,11 +4428,6 @@ LogicalResult ONNXLpPoolOp::inferShapes( return emitError(NOT_IMPLEMENTED_MESSAGE); } -LogicalResult ONNXMatMulIntegerOp::inferShapes( - std::function doShapeInference) { - return emitError(NOT_IMPLEMENTED_MESSAGE); -} - LogicalResult ONNXMaxPoolOp::inferShapes( std::function doShapeInference) { return emitError(NOT_IMPLEMENTED_MESSAGE); diff --git a/src/Dialect/ONNX/ShapeInference/MatMul.cpp b/src/Dialect/ONNX/ShapeInference/GenericMatMul.cpp similarity index 79% rename from src/Dialect/ONNX/ShapeInference/MatMul.cpp rename to src/Dialect/ONNX/ShapeInference/GenericMatMul.cpp index 762a6bb1ec..e89a29ab33 100644 --- a/src/Dialect/ONNX/ShapeInference/MatMul.cpp +++ b/src/Dialect/ONNX/ShapeInference/GenericMatMul.cpp @@ -2,41 +2,49 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===------------ MatMul.cpp - Shape Inference for MatMul Op --------------===// +//===--------- GenericMatMul.cpp - Shape Inference for matmul Ops ---------===// // -// This file implements shape inference for the ONNX MatMul Operator. +// This file implements shape inference for the ONNX MatMul, QLinearMatMul, and +// MatMulInteger Operators. // //===----------------------------------------------------------------------===// +#include +#include + #include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp" using namespace mlir; namespace onnx_mlir { -ONNXMatMulOpShapeHelper::ONNXMatMulOpShapeHelper( - ONNXMatMulOp *newOp, IndexExprScope *inScope) - : ONNXOpShapeHelper( - newOp, newOp->getOperation()->getNumResults(), inScope), - aDims(), bDims(), aPadDims(), bPadDims() {} +template +std::pair matMulInputs(OpAdaptor &operandAdaptor) { + Value A = operandAdaptor.A(); + Value B = operandAdaptor.B(); + return std::pair(A, B); +} -ONNXMatMulOpShapeHelper::ONNXMatMulOpShapeHelper(ONNXMatMulOp *newOp, - OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, - ArrayValueIndexCapture::LoadVal fLoadVal, IndexExprScope *inScope) - : ONNXOpShapeHelper(newOp, - newOp->getOperation()->getNumResults(), rewriter, fGetDenseVal, - fLoadVal, inScope), - aDims(), bDims(), aPadDims(), bPadDims() {} +template <> +std::pair matMulInputs( + ONNXQLinearMatMulOpAdaptor &operandAdaptor) { + Value A = operandAdaptor.a(); + Value B = operandAdaptor.b(); + return std::pair(A, B); +} + +template +LogicalResult ONNXGenericMatMulOpShapeHelper::computeShape( + typename OP_TYPE::Adaptor operandAdaptor) { -LogicalResult ONNXMatMulOpShapeHelper::computeShape( - ONNXMatMulOpAdaptor operandAdaptor) { // Shape inference indicated by passing a null rewriter pointer. // Output dims of result. DimsExpr outputDims; // Get info. - Value A = operandAdaptor.A(); - Value B = operandAdaptor.B(); + Value A; + Value B; + std::tie(A, B) = matMulInputs(operandAdaptor); MemRefBoundsIndexCapture ABounds(A); MemRefBoundsIndexCapture BBounds(B); @@ -93,7 +101,7 @@ LogicalResult ONNXMatMulOpShapeHelper::computeShape( } else if (aDims[i].isLiteral() && bDims[i].isLiteral()) { // No broadcast, both literals, make sure they have the same value. if (aDims[i].getLiteral() != bDims[i].getLiteral()) - return op->emitError("Incompatible size detected"); + return this->op->emitError("Incompatible size detected"); outputDims.emplace_back(aDims[i]); } else if (aDims[i].isLiteral()) { // A dim is a literal; use it here for output and b, since b @@ -119,7 +127,7 @@ LogicalResult ONNXMatMulOpShapeHelper::computeShape( // And test the K dimensions. if (aDims[aK].isLiteral() && bDims[bK].isLiteral()) { if (aDims[aK].getLiteral() != bDims[bK].getLiteral()) - return op->emitError("reduction dimension must be the same"); + return this->op->emitError("reduction dimension must be the same"); } else if (aDims[aK].isLiteral()) { // Save aK dims into bK dims, in case bK dims was runtime bDims[bK] = aDims[aK]; @@ -137,8 +145,12 @@ LogicalResult ONNXMatMulOpShapeHelper::computeShape( assert(outputDims.empty() && "1-D x 1-D results in scalar"); } // Save the final result. - setOutputDims(outputDims); + this->setOutputDims(outputDims); return success(); } +template struct ONNXGenericMatMulOpShapeHelper; +template struct ONNXGenericMatMulOpShapeHelper; +template struct ONNXGenericMatMulOpShapeHelper; + } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp index dde7531bd9..0a1a068f02 100644 --- a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp +++ b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp @@ -534,6 +534,8 @@ template struct ONNXOpShapeHelper; template struct ONNXOpShapeHelper; template struct ONNXOpShapeHelper; template struct ONNXOpShapeHelper; +template struct ONNXOpShapeHelper; +template struct ONNXOpShapeHelper; template struct ONNXOpShapeHelper; template struct ONNXOpShapeHelper; template struct ONNXOpShapeHelper; diff --git a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp index c6e3f4d3dc..22f94c43dc 100644 --- a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp +++ b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp @@ -314,15 +314,22 @@ struct ONNXGemmOpShapeHelper : public ONNXOpShapeHelper { int cRank; // Dim of the original C (not padding dims by 1). }; -// Shape for MatMulOp. -struct ONNXMatMulOpShapeHelper : public ONNXOpShapeHelper { - ONNXMatMulOpShapeHelper( - mlir::ONNXMatMulOp *newOp, IndexExprScope *inScope = nullptr); - ONNXMatMulOpShapeHelper(mlir::ONNXMatMulOp *newOp, mlir::OpBuilder *rewriter, +template +struct ONNXGenericMatMulOpShapeHelper : public ONNXOpShapeHelper { + ONNXGenericMatMulOpShapeHelper( + OP_TYPE *newOp, IndexExprScope *inScope = nullptr) + : ONNXOpShapeHelper( + newOp, newOp->getOperation()->getNumResults(), inScope), + aDims(), bDims(), aPadDims(), bPadDims() {} + ONNXGenericMatMulOpShapeHelper(OP_TYPE *newOp, mlir::OpBuilder *rewriter, ArrayValueIndexCapture::GetDenseVal fGetDenseVal, ArrayValueIndexCapture::LoadVal fLoadVal, - IndexExprScope *inScope = nullptr); - mlir::LogicalResult computeShape(mlir::ONNXMatMulOpAdaptor operandAdaptor); + IndexExprScope *inScope = nullptr) + : ONNXOpShapeHelper(newOp, + newOp->getOperation()->getNumResults(), rewriter, fGetDenseVal, + fLoadVal, inScope), + aDims(), bDims(), aPadDims(), bPadDims() {} + mlir::LogicalResult computeShape(typename OP_TYPE::Adaptor operandAdaptor); // Additional data for MatMulOp: output = a & b. llvm::SmallVector aDims, bDims; // Dim after applying padding. @@ -330,6 +337,13 @@ struct ONNXMatMulOpShapeHelper : public ONNXOpShapeHelper { bPadDims; // When true, that dim was padded. }; +using ONNXMatMulOpShapeHelper = + ONNXGenericMatMulOpShapeHelper; +using ONNXMatMulIntegerOpShapeHelper = + ONNXGenericMatMulOpShapeHelper; +using ONNXQLinearMatMulOpShapeHelper = + ONNXGenericMatMulOpShapeHelper; + // Shape for PadOp. struct ONNXPadOpShapeHelper : public ONNXOpShapeHelper { ONNXPadOpShapeHelper( diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 11daeccf42..8a18b3175c 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -255,6 +255,33 @@ func.func @test_matmul_10(%arg0 : tensor, %arg1 : tensor<32xf32>) - // ----- +/// QLinearMatMul + +func.func @test_qlinearmatmul_1(%arg0: tensor<2x2x4xui8>, %arg1: tensor<1xf32>, %arg2: tensor<1xui8>, %arg3: tensor<2x4x3xui8>, %arg4: tensor<1xf32>, %arg5: tensor<1xui8>, %arg6: tensor<1xf32>, %arg7: tensor<1xui8>) -> tensor<*xui8> { + %0 = "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2x2x4xui8>, tensor<1xf32>, tensor<1xui8>, tensor<2x4x3xui8>, tensor<1xf32>, tensor<1xui8>, tensor<1xf32>, tensor<1xui8>) -> tensor<*xui8> + "func.return"(%0) : (tensor<*xui8>) -> () + + + // CHECK-LABEL: test_qlinearmatmul_1 + // CHECK: [[RES1:%.+]] = "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2x2x4xui8>, tensor<1xf32>, tensor<1xui8>, tensor<2x4x3xui8>, tensor<1xf32>, tensor<1xui8>, tensor<1xf32>, tensor<1xui8>) -> tensor<2x2x3xui8> + // CHECK: return [[RES1]] : tensor<2x2x3xui8> +} + +// ----- + +/// MatMulInteger + +func.func @test_matmulinteger_1(%arg0: tensor<4x3xui8>, %arg1: tensor<3x2xui8>, %arg2: tensor<1xui8>, %arg3: tensor<1xui8>) -> tensor<*xi32> { + %0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x3xui8>, tensor<3x2xui8>, tensor<1xui8>, tensor<1xui8>) -> tensor<*xi32> + return %0 : tensor<*xi32> + + // CHECK-LABEL: test_matmulinteger_1 + // CHECK: [[RES1:%.+]] = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x3xui8>, tensor<3x2xui8>, tensor<1xui8>, tensor<1xui8>) -> tensor<4x2xi32> + // CHECK: return [[RES1]] : tensor<4x2xi32> +} + +// ----- + //===----------------------------------------------------------------------===// /// Test shape inference for Conv (first with no bias) operation and all its attributes. //===----------------------------------------------------------------------===//