From 57265bf9f15aeac424a445b631f8f80dab8aa682 Mon Sep 17 00:00:00 2001 From: Chen Tong Date: Mon, 11 Mar 2024 21:19:50 -0400 Subject: [PATCH] int32 pass Signed-off-by: Chen Tong --- src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp | 2 ++ src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp | 2 ++ .../ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp | 10 ++++++++++ test/backend/inference_backend.py | 13 +++++++++---- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp index 8ac952067b..7ba3742b17 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp @@ -74,6 +74,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { auto indicesType = indices.getType().cast(); ArrayRef indicesShape = indicesType.getShape(); int64_t indicesLastDim = indicesShape[indicesRank - 1]; + // ToFix: Handle case in which indicesLastDim is kDynamic. + // Currently, such case is detected by ONNXPreKrnlVerifyPass. assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) && "indices.shape[-1] must be in the range [1, dataRank - b]"); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp index c0a2c343a2..eeec664491 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp @@ -50,6 +50,8 @@ LogicalResult ONNXGatherNDOpShapeHelper::computeShape() { assert(b >= 0 && "batch_dim should not be negative"); assert(b < std::min(dataRank, indicesRank) && "batch_dims must be smaller than the min(dataRank, indicesRank)"); + // ToFix: Handle the case when indicesLastDim is dynamic and the rank + // of output tensor is known. assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) && "indices.shape[-1] must be in the range [1, dataRank - b]"); diff --git a/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp b/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp index 20b840a847..ebd2978d0c 100644 --- a/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp +++ b/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp @@ -25,6 +25,7 @@ #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Interface/ShapeInferenceOpInterface.hpp" #include "src/Pass/Passes.hpp" +#include "src/Support/TypeUtilities.hpp" using namespace mlir; @@ -70,6 +71,15 @@ class ONNXPreKrnlVerifyPass : public mlir::PassWrapper() && !ty.isa()) { op.emitError("not ranked"); return failure(); + } else if (ONNXGatherNDOp gatherNDOp = llvm::dyn_cast(op)) { + Value indices = gatherNDOp.getIndices(); + Type indicesType = indices.getType(); + ArrayRef indicesShape = onnx_mlir::getShape(indices.getType()); + int64_t indicesRank = onnx_mlir::getRank(indicesType); + if (indicesShape[indicesRank - 1] == ShapedType::kDynamic) { + op.emitError("last dim of indices in GatherND is dynamic"); + return failure(); + } } } return success(); diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index b73c571ccf..ae8bfc1b41 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -1092,22 +1092,27 @@ def get_test_models(): }, # ==OP== GatherND # ==MIN== 11 + # According to onnx specification: + # output tensor of rank q + r - indices_shape[-1] - 1 - b + # onnx-mlir can not infer the shape if the last dim of indices + # input is dynamic. Therefore, {0: {-1}} is used for DYNAMIC_SHAPE "test_gathernd_example_int32_cpu": { STATIC_SHAPE: {}, - # Issue #2639: Dynamic test fails. Need to be fixed. - # DYNAMIC_SHAPE: {-1: {-1}}, + DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, "test_gathernd_example_float32_cpu": { STATIC_SHAPE: {}, # Issue #2639: Dynamic test fails. Need to be fixed. - # DYNAMIC_SHAPE: {-1: {-1}}, + # Dynamic test failed due to result difference + # DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, "test_gathernd_example_int32_batch_dim1_cpu": { STATIC_SHAPE: {}, # Issue #2639: Dynamic test fails. Need to be fixed. - # DYNAMIC_SHAPE: {-1: {-1}}, + # Dynamic test failed due to result difference + # DYNAMIC_SHAPE: {0: {-1}}, CONSTANT_INPUT: {-1}, }, # ==OP== Gelu