Skip to content

Commit

Permalink
int32 pass
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Tong <[email protected]>
  • Loading branch information
chentong319 committed Mar 12, 2024
1 parent 7c88851 commit 57265bf
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern<ONNXGatherNDOp> {
auto indicesType = indices.getType().cast<ShapedType>();
ArrayRef<int64_t> indicesShape = indicesType.getShape();
int64_t indicesLastDim = indicesShape[indicesRank - 1];
// ToFix: Handle case in which indicesLastDim is kDynamic.
// Currently, such case is detected by ONNXPreKrnlVerifyPass.
assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
"indices.shape[-1] must be in the range [1, dataRank - b]");

Expand Down
2 changes: 2 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ LogicalResult ONNXGatherNDOpShapeHelper::computeShape() {
assert(b >= 0 && "batch_dim should not be negative");
assert(b < std::min(dataRank, indicesRank) &&
"batch_dims must be smaller than the min(dataRank, indicesRank)");
// ToFix: Handle the case when indicesLastDim is dynamic and the rank
// of output tensor is known.
assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
"indices.shape[-1] must be in the range [1, dataRank - b]");

Expand Down
10 changes: 10 additions & 0 deletions src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Interface/ShapeInferenceOpInterface.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;

Expand Down Expand Up @@ -70,6 +71,15 @@ class ONNXPreKrnlVerifyPass : public mlir::PassWrapper<ONNXPreKrnlVerifyPass,
} else if (!ty.isa<RankedTensorType>() && !ty.isa<NoneType>()) {
op.emitError("not ranked");
return failure();
} else if (ONNXGatherNDOp gatherNDOp = llvm::dyn_cast<ONNXGatherNDOp>(op)) {
Value indices = gatherNDOp.getIndices();
Type indicesType = indices.getType();
ArrayRef<int64_t> indicesShape = onnx_mlir::getShape(indices.getType());
int64_t indicesRank = onnx_mlir::getRank(indicesType);
if (indicesShape[indicesRank - 1] == ShapedType::kDynamic) {
op.emitError("last dim of indices in GatherND is dynamic");
return failure();
}
}
}
return success();
Expand Down
13 changes: 9 additions & 4 deletions test/backend/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,22 +1092,27 @@ def get_test_models():
},
# ==OP== GatherND
# ==MIN== 11
# According to onnx specification:
# output tensor of rank q + r - indices_shape[-1] - 1 - b
# onnx-mlir can not infer the shape if the last dim of indices
# input is dynamic. Therefore, {0: {-1}} is used for DYNAMIC_SHAPE
"test_gathernd_example_int32_cpu": {
STATIC_SHAPE: {},
# Issue #2639: Dynamic test fails. Need to be fixed.
# DYNAMIC_SHAPE: {-1: {-1}},
DYNAMIC_SHAPE: {0: {-1}},
CONSTANT_INPUT: {-1},
},
"test_gathernd_example_float32_cpu": {
STATIC_SHAPE: {},
# Issue #2639: Dynamic test fails. Need to be fixed.
# DYNAMIC_SHAPE: {-1: {-1}},
# Dynamic test failed due to result difference
# DYNAMIC_SHAPE: {0: {-1}},
CONSTANT_INPUT: {-1},
},
"test_gathernd_example_int32_batch_dim1_cpu": {
STATIC_SHAPE: {},
# Issue #2639: Dynamic test fails. Need to be fixed.
# DYNAMIC_SHAPE: {-1: {-1}},
# Dynamic test failed due to result difference
# DYNAMIC_SHAPE: {0: {-1}},
CONSTANT_INPUT: {-1},
},
# ==OP== Gelu
Expand Down

0 comments on commit 57265bf

Please sign in to comment.