Add decomposition for aten.flatten.using_ints
Tanyo Kwok committed Aug 23, 2022
1 parent 01290d1 commit 756c196
Showing 3 changed files with 86 additions and 4 deletions.
64 changes: 64 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -834,6 +834,68 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
};
} // namespace

// Decompose aten.flatten.using_ints into aten.view op.
namespace {
class DecomposeAtenFlattenUsingIntsOp
: public OpRewritePattern<AtenFlattenUsingIntsOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
MLIRContext *context = op.getContext();
int64_t rank = getTensorRank(self);
if (rank < 0)
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");

int64_t start, end;
if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start)) ||
!matchPattern(op.end_dim(), m_TorchConstantInt(&end))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: requires start and end dims to be constants");
}

SmallVector<Value, 4> newSizes;
if (rank == 0) {
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
newSizes.push_back(one);
} else {
start = toPositiveDim(start, rank);
end = toPositiveDim(end, rank);

      if (start > end) {
        return rewriter.notifyMatchFailure(
            op, "expected end dim to be greater than or equal to start dim");
      }

newSizes.reserve(rank - end + start);
      for (int64_t k = 0; k < start; ++k) {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(k));
newSizes.push_back(
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
}
Value flattenDimSize =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
newSizes.push_back(flattenDimSize);
      for (int64_t k = end + 1; k < rank; ++k) {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(k));
newSizes.push_back(
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
}
}
Value newSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), newSizes);
    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), self,
                                            newSizeList);
return success();
}
};
} // namespace
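
Note: the pattern keeps the dims outside [start_dim, end_dim] via aten.size.int and lets aten.view infer the collapsed extent from -1, which avoids multiplying the collapsed sizes explicitly and works for dynamic dims. A minimal eager-mode sketch of that equivalence, using only stock PyTorch (nothing from this patch):

import torch

# Sketch of the equivalence the pattern relies on: flatten(start_dim, end_dim)
# is a view that keeps the leading/trailing sizes and uses -1 for the
# collapsed range.
x = torch.randn(2, 3, 4, 5)
flattened = torch.flatten(x, 1, 2)         # collapse dims 1..2
viewed = x.view(x.size(0), -1, x.size(3))  # what the decomposition emits
assert flattened.shape == viewed.shape == torch.Size([2, 12, 5])

# Rank-0 special case handled above: flattening a scalar yields a
# one-element 1-D tensor, modeled as view([1]).
scalar = torch.tensor(3.0)
assert torch.flatten(scalar).shape == scalar.view(1).shape == torch.Size([1])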

// Decompose aten.expand into aten.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
@@ -2497,6 +2559,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenRepeatOp>();
patterns.add<DecomposeAtenExpandOp>(context);
target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenFlattenUsingIntsOp>(context);
target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<DecomposeAtenWhereScalarOp>(context);
target.addIllegalOp<AtenWhereScalarOp>();
patterns.add<DecomposeAtenWhereScalarOtherOp>(context);
4 changes: 2 additions & 2 deletions python/torch_mlir/__init__.py
@@ -126,8 +126,8 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: [],
-    OutputType.LINALG_ON_TENSORS: [],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints',],
+    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
OutputType.MHLO: [],
}
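
A hedged sketch of what this registry change means in practice, assuming the torch_mlir.compile entry point defined in this same file; with the op backend-legal for TOSA, it should reach the backend contract intact instead of decomposing to view. FlattenModule is an illustrative name, not part of the patch:

import torch
import torch_mlir

class FlattenModule(torch.nn.Module):
    def forward(self, x):
        return torch.flatten(x, 1, 2)

# 'torch.aten.flatten.using_ints' is now backend-legal for TOSA, so the
# compiled module should still contain the flatten op rather than a view.
compiled = torch_mlir.compile(FlattenModule(), torch.ones(2, 3, 4, 5),
                              output_type=torch_mlir.OutputType.TOSA)
print(compiled)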

22 changes: 20 additions & 2 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -202,8 +202,10 @@ func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
-// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] :
-// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
+// CHECK: %[[CST_NEG1:.*]] = torch.constant.int -1
+// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[CST_NEG1]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[FLATTEN:.*]] = torch.aten.view %[[INP]], %[[T0]] :
+// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] :
// CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[],si64>
@@ -1332,3 +1334,19 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
%0 = torch.aten.std.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,5],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
return %0 : !torch.vtensor<[3,4,1],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[INT_NEG1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[T1:.*]] = torch.aten.view %[[ARG0]], %[[T0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T1]] : !torch.vtensor<[?],f32>
func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
%1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
return %1 : !torch.vtensor<[?],f32>
}
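
For reference, the eager-mode behavior this test pins down: a full flatten (start_dim 0, end_dim 3) collapses every dim, so the size list the decomposition builds reduces to [-1]. Plain PyTorch, illustrative only:

import torch

# Full flatten, matching the test: dims 0..3 collapse to a single dim,
# so the emitted torch.prim.ListConstruct holds just the constant -1.
x = torch.randn(2, 3, 4, 5)
assert torch.flatten(x, 0, 3).shape == x.view(-1).shape == torch.Size([120])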
