Skip to content

Commit

Permalink
Add Shape op verifier (#1711)
Browse files Browse the repository at this point in the history
* Add Shape op verifier

Signed-off-by: Philip Lassen <[email protected]>
  • Loading branch information
philass authored Sep 20, 2022
1 parent b6a17f6 commit 51dcbf9
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
| **SequenceErase** | |unsupported | |
| **SequenceInsert** |11 |Does not support unranked sequence element. | |
| **SequenceLength** | |unsupported | |
| **Shape** |13 | | |
| **Shape** |15 |Does not support start and end attributes. | |
| **Shrink** | |unsupported | |
| **Sigmoid** |13 | | |
| **Sign** |13 | | |
Expand Down
12 changes: 12 additions & 0 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3436,6 +3436,18 @@ LogicalResult ONNXShapeOp::inferShapes(
ONNXShapeOpAdaptor>(*this, elementType);
}

LogicalResult ONNXShapeOp::verify() {
if (!data().getType().isa<RankedTensorType>())
return success();
ONNXShapeOpAdaptor operandAdaptor(*this);
int64_t start;
int64_t end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);
if (start > end)
return emitOpError() << "Start: " << start << " is after End: " << end;
return success();
}

//===----------------------------------------------------------------------===//
// Size
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5609,6 +5609,7 @@ def ONNXShapeOp:ONNX_Op<"Shape",
return {4};
}
}];
let hasVerifier = 1;
}

def ONNXShrinkOp:ONNX_Op<"Shrink",
Expand Down
2 changes: 2 additions & 0 deletions src/Dialect/ONNX/ShapeInference/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ LogicalResult ONNXShapeOpShapeHelper::computeShape(
int64_t end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);

assert(start <= end && "Start must not be greater than end");

// Output is the actual number of values (1D)
dimsForOutput().emplace_back(LiteralIndexExpr(end - start));

Expand Down
1 change: 1 addition & 0 deletions test/backend/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ def get_test_models():
#"test_sequence_insert_at_back_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== Shape
# ==LIM== Does not support start and end attributes.
"test_shape_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_shape_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

Expand Down
8 changes: 8 additions & 0 deletions test/mlir/onnx/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@ func.func @test_scatterelements_verifier_2(%arg0: tensor<2x2xf32>, %arg1: tensor

// -----

func.func @test_shape_to_dim_positive_axis_verifier(%arg0: tensor<?x256x?xi64>) -> tensor<2xi64> {
// expected-error @+1 {{'onnx.Shape' op Start: 2 is after End: 0}}
%0 = "onnx.Shape"(%arg0) {end = 0 : si64, start = -1 : si64} : (tensor<?x256x?xi64>) -> tensor<2xi64>
return %0 : tensor<2xi64>
}

// -----

func.func @test_logsoftmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> {
// expected-error @+1 {{onnx.LogSoftmax: 'axis' value is 3, accepted range is [-2, 1]}}
%1 = "onnx.LogSoftmax"(%arg0) {axis = 3 : si64} : (tensor<2x2xf32>) -> tensor<*xf32>
Expand Down
1 change: 1 addition & 0 deletions utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@
'ScatterND',
'SequenceEmpty',
'SequenceInsert',
'Shape',
'SpaceToDepth',
'Split',
'SplitToSequence',
Expand Down

0 comments on commit 51dcbf9

Please sign in to comment.