Skip to content

Commit

Permalink
onnx: fix checks for axes and steps inputs of Slice operator
Browse files Browse the repository at this point in the history
The ONNX Spec for the Slice operator allows the `starts` and `ends`
inputs to have fewer indices that the dimensions of the `data` tensor
(see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
expects these inputs to be as many as the `data` tensor's dimensions.

More precisely, the spec requires that the `starts` and `ends` inputs
are only as long as the `axes` input, but since the `axes` input is
optional, the default type for the `axes` input has to match the type
for the `starts` and `ends` inputs.  Moreover, the number of indices in
the `steps` input also has to match those in the `axes` inputs (instad
of matching the dimensions of the `data` input).

This patch fixes the checks in the TorchOnnxToTorch conversion so that
they match the ONNX spec.
  • Loading branch information
Ashay Rane committed Feb 1, 2024
1 parent c7d7d7f commit 4bb7146
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
22 changes: 7 additions & 15 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1436,18 +1436,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return failure();
}
} else {
// The default axes value is the range from 0 to the number of
// dimensions
// The default axes value is the range from 0 to the size of first
// dimension of `starts` and `ends`.
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
auto defaultAxesType = Torch::ValueTensorType::get(
context, ArrayRef<int64_t>{operandTy.getRank()},
rewriter.getIntegerType(64, /*signed*/ 1));
Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
operandTy.getRank()));
rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
axes = rewriter.create<Torch::AtenArangeOp>(
loc, defaultAxesType, arangeLength, none, none, none, none);
loc, startsTorchTy, arangeLength, none, none, none, none);
}

// Binding `steps` from its arguments or through a default value
Expand All @@ -1458,22 +1454,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
}
} else {
// The default `steps` value is a 1d tensor filled with ones with a
// size of the dimension of the operand
// size equal to the size of `starts` and `ends`.
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
auto defaultStepsType = Torch::ValueTensorType::get(
context, ArrayRef<int64_t>{operandTy.getRank()},
rewriter.getIntegerType(64, /*signed*/ 1));
Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
operandTy.getRank()));
rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
sizeStepInput);
steps = rewriter.create<Torch::AtenOnesOp>(
loc, defaultStepsType, sizeStepsInput, none, none, none, none);
loc, startsTorchTy, sizeStepsInput, none, none, none, none);
}

if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
Expand Down
25 changes: 25 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,31 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3

// -----

// CHECK-LABEL: @test_slice_default_axes_and_steps
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[20,10,5],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64>

// CHECK: %[[ZERO0:.*]] = torch.constant.int 0
// CHECK: %[[ZERO1:.*]] = torch.constant.int 0
// CHECK: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[SELECT2:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ITEM2]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>

func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
%0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32>
return %0 : !torch.vtensor<[20,10,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_slice_default_steps
func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
//CHECK: %[[NONE:.*]] = torch.constant.none
Expand Down

0 comments on commit 4bb7146

Please sign in to comment.