Skip to content

Commit

Permalink
Fix SliceOp::fold to check step (introduced by #37)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Jun 12, 2023
1 parent 3f8a1cd commit 828e99e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2266,10 +2266,11 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
int64_t start;
int64_t end;
int64_t start, end, step;
if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
matchPattern(getEnd(), m_TorchConstantInt(&end))
matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
matchPattern(getStep(), m_TorchConstantInt(&step))
&& step == 1
&& start == 0
&& end == std::numeric_limits<int64_t>::max())
return getOperand(0);
Expand Down
21 changes: 21 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,27 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
return %0 : !torch.vtensor<[4],f32>
}

// Verifies that a full-range slice (start = 0, end = INT64_MAX, step = 1)
// folds away to its input operand, even when the sliced dimension is dynamic
// and the slice dimension itself is a non-constant value.
// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[?],f32>
// CHECK: return %[[ARG0]] : !torch.vtensor<[?],f32>
func.func @torch.aten.slice.tensor$fold_full_slice(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> {
%int1 = torch.constant.int 1
// INT64_MAX is PyTorch's sentinel for "slice to the end".
%int9223372036854775807 = torch.constant.int 9223372036854775807
%int0 = torch.constant.int 0
%0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32>
return %0 : !torch.vtensor<[?],f32>
}

// Regression test for the bug fixed by this commit: a slice with step != 1
// is NOT a no-op (it drops elements), so the fold must not fire and the
// torch.aten.slice.Tensor op must remain in the output.
// CHECK-LABEL: @torch.aten.slice.tensor$no_fold_step
// CHECK: torch.aten.slice.Tensor
func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> {
// step = 2: selects every other element, so the result differs from the input.
%int2 = torch.constant.int 2
%int9223372036854775807 = torch.constant.int 9223372036854775807
%int0 = torch.constant.int 0
%0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32>
return %0 : !torch.vtensor<[?],f32>
}

// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
Expand Down

0 comments on commit 828e99e

Please sign in to comment.