diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 7dd99c72ab42..5357fc233e9d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2266,10 +2266,11 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - int64_t start; - int64_t end; + int64_t start, end, step; if (matchPattern(getStart(), m_TorchConstantInt(&start)) && - matchPattern(getEnd(), m_TorchConstantInt(&end)) + matchPattern(getEnd(), m_TorchConstantInt(&end)) && + matchPattern(getStep(), m_TorchConstantInt(&step)) + && step == 1 && start == 0 && end == std::numeric_limits::max()) return getOperand(0); diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index b4f9db5df4ef..8d11c640d7c9 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1913,6 +1913,27 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor< return %0 : !torch.vtensor<[4],f32> } +// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice +// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[?],f32> +// CHECK: return %[[ARG0]] : !torch.vtensor<[?],f32> +func.func @torch.aten.slice.tensor$fold_full_slice(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> { + %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int0 = torch.constant.int 0 + %0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32> + return %0 : !torch.vtensor<[?],f32> +} + +// CHECK-LABEL: @torch.aten.slice.tensor$no_fold_step +// CHECK: torch.aten.slice.Tensor +func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> { + %int2 = torch.constant.int 2 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int0 = torch.constant.int 0 + %0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32> + return %0 : !torch.vtensor<[?],f32> +} + // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %int-1 = torch.constant.int -1 // CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>