Skip to content

Commit

Permalink
[ONNX] Add OnnxToTorch support for SequenceMap (llvm#3535)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 authored Jul 17, 2024
1 parent fde286f commit f0ce1e9
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
81 changes: 81 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3359,6 +3359,87 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.replaceOp(binder.op, inputSequence);
return success();
});
patterns.onOp(
"SequenceMap", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<Value> operands;
Torch::ListType resultType;
if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
binder.tensorListResultType(resultType)) {
return failure();
}

Region *bodyRegion;
if (binder.getRegionAtIndex(bodyRegion, 0)) {
return rewriter.notifyMatchFailure(binder.op,
"Failed getting Body Region");
}

// construct an empty list, append results through the loop
auto resultTensorType =
dyn_cast<Torch::ValueTensorType>(resultType.getContainedType());
Value shapeList = createConstantIntList(binder, rewriter,
resultTensorType.getSizes());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
binder.op->getLoc(), resultType.getContainedType(), shapeList,
/*dtype=*/cstNone, /*layout=*/cstNone, /*device=*/cstNone,
/*pinMemory=*/cstNone, /*memoryFormat=*/cstNone);
Value result = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), resultType, llvm::SmallVector<Value>{self});

// create a for-like primLoopOp
// with the length of sequence as max iter_num
Value len = rewriter.create<Torch::AtenLenTOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[0]);
auto cstTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
mlir::ImplicitLocOpBuilder b(binder.getLoc(), rewriter);
auto loop =
b.create<Torch::PrimLoopOp>(resultType, len, cstTrue, result);
rewriter.cloneRegionBefore(*bodyRegion, loop.getRegion(),
loop.getRegion().begin());

// primLoopOp loopBody expects torch.int as first arg
// remove inputs from the region and use it from outside
loop.getRegion().front().insertArgument(0U, resultType,
binder.getLoc());
Value sequenceArg = loop.getRegion().front().getArgument(0);
loop.getRegion().front().insertArgument(
0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
Value indexArg = loop.getRegion().front().getArgument(0);

// get sequence[i] (and addtionalInput[i]) in each iteration
rewriter.setInsertionPointToStart(&loop.getRegion().front());
for (size_t i = 0; i < operands.size(); i++) {
Value argInput = loop.getRegion().front().getArgument(2);
if (isa<Torch::ListType>(operands[i].getType())) {
auto tensorType = dyn_cast<Torch::ValueTensorType>(
dyn_cast<Torch::ListType>(operands[i].getType())
.getContainedType());
Value item = rewriter.create<Torch::Aten__Getitem__TOp>(
binder.getLoc(), tensorType, operands[i], indexArg);
argInput.replaceAllUsesWith(item);
} else {
argInput.replaceAllUsesWith(operands[i]);
}
loop.getRegion().eraseArgument(2);
}

// replace terminator
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = loop.getRegion().front().getTerminator();
rewriter.setInsertionPoint(terminator);
// update sequence input
auto terminatorOperands = terminator->getOperands();
Value append = rewriter.create<Torch::AtenAppendTOp>(
binder.getLoc(), resultType, sequenceArg, terminatorOperands[0]);
rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
terminator, cstTrue, append);

rewriter.replaceOp(binder.op, loop);
return success();
});
patterns.onOp(
"Upsample", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
118 changes: 118 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2575,6 +2575,124 @@ func.func @test_sequence_empty() -> !torch.list<vtensor<[],f32>> attributes {tor

// -----

// CHECK-LABEL: func.func @test_sequence_map_add
func.func @test_sequence_map_add(%arg0: !torch.list<vtensor<[2,3],f32>>, %arg1: !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
// CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[2,3],f32>> -> !torch.int
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[2,3],f32>>):
// CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[2,3],f32>>, !torch.int -> !torch.vtensor<[2,3],f32>
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[SAMPLE]], %arg1, %[[C1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
// CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[ADD]] : !torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor<[2,3],f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[2,3],f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[2,3],f32>>) -> !torch.list<vtensor<[2,3],f32>>
// CHECK: return %[[LOOP]] : !torch.list<vtensor<[2,3],f32>>
%0 = torch.operator "onnx.SequenceMap"(%arg0, %arg1) : (!torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>> {
^bb0(%arg2: !torch.vtensor<[2,3],f32>, %arg3: !torch.vtensor<[2,3],f32>):
%1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
torch.operator_terminator %1 : !torch.vtensor<[2,3],f32>
}
return %0 : !torch.list<vtensor<[2,3],f32>>
}

// -----

// CHECK-LABEL: func.func @test_sequence_map_add_sequence_variadic
func.func @test_sequence_map_add_sequence_variadic(%arg0: !torch.list<vtensor<[?],f32>>, %arg1: !torch.list<vtensor<[?],f32>>, %arg2: !torch.vtensor<[?],f32>) -> !torch.list<vtensor<[?],f32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NEG1:.*]] = torch.constant.int -1
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[NEG1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[?],f32>) -> !torch.list<vtensor<[?],f32>>
// CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?],f32>> -> !torch.int
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[?],f32>>):
// CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[?],f32>>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[ADDITION_INPUT:.*]] = torch.aten.__getitem__.t %arg1, %[[ITER_NUM]] : !torch.list<vtensor<[?],f32>>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[SAMPLE]], %[[ADDITION_INPUT]], %[[C1]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[ADD_0:.*]] = torch.aten.add.Tensor %[[ADD]], %arg2, %[[C1_0]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[ADD_0]] : !torch.list<vtensor<[?],f32>>, !torch.vtensor<[?],f32> -> !torch.list<vtensor<[?],f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[?],f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[?],f32>>) -> !torch.list<vtensor<[?],f32>>
// CHECK: return %[[LOOP]] : !torch.list<vtensor<[?],f32>>
%0 = torch.operator "onnx.SequenceMap"(%arg0, %arg1, %arg2) : (!torch.list<vtensor<[?],f32>>, !torch.list<vtensor<[?],f32>>, !torch.vtensor<[?],f32>) -> !torch.list<vtensor<[?],f32>> {
^bb0(%arg3: !torch.vtensor<[?],f32>, %arg4: !torch.vtensor<[?],f32>, %arg5: !torch.vtensor<[?],f32>):
%1 = torch.operator "onnx.Add"(%arg3, %arg4) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32>
%2 = torch.operator "onnx.Add"(%1, %arg5) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32>
torch.operator_terminator %2 : !torch.vtensor<[?],f32>
}
return %0 : !torch.list<vtensor<[?],f32>>
}

// -----

// CHECK-LABEL: func.func @test_sequence_map_identity
func.func @test_sequence_map_identity(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[?,?,?],f32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NEG1:.*]] = torch.constant.int -1
// CHECK: %[[NEG1_0:.*]] = torch.constant.int -1
// CHECK: %[[NEG1_1:.*]] = torch.constant.int -1
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[NEG1]], %[[NEG1_0]], %[[NEG1_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[?,?,?],f32>) -> !torch.list<vtensor<[?,?,?],f32>>
// CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[?,?,?],f32>>):
// CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[?,?,?],f32>>, !torch.int -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[CLONE:.*]] = torch.aten.clone %[[SAMPLE]], %[[NONE_0]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[CLONE]] : !torch.list<vtensor<[?,?,?],f32>>, !torch.vtensor<[?,?,?],f32> -> !torch.list<vtensor<[?,?,?],f32>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[?,?,?],f32>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[?,?,?],f32>>
// CHECK: return %[[LOOP]] : !torch.list<vtensor<[?,?,?],f32>>
%0 = torch.operator "onnx.SequenceMap"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[?,?,?],f32>> {
^bb0(%arg1: !torch.vtensor<[?,?,?],f32>):
%1 = torch.operator "onnx.Identity"(%arg1) : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32>
torch.operator_terminator %1 : !torch.vtensor<[?,?,?],f32>
}
return %0 : !torch.list<vtensor<[?,?,?],f32>>
}

// -----

// CHECK-LABEL: func.func @test_sequence_map_extract_shapes
func.func @test_sequence_map_extract_shapes(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[3],si64>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[SHAPE]] = torch.prim.ListConstruct %[[C3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor<[3],si64>>
// CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[3],si64>>):
// CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[?,?,?],f32>>, !torch.int -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SHAPE_0:.*]] = torch.aten._shape_as_tensor %[[SAMPLE]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[3],si64>
// CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[SHAPE_0]] : !torch.list<vtensor<[3],si64>>, !torch.vtensor<[3],si64> -> !torch.list<vtensor<[3],si64>>
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[3],si64>>)
// CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[3],si64>>) -> !torch.list<vtensor<[3],si64>>
// CHECK: return %[[LOOP]] : !torch.list<vtensor<[3],si64>>
%0 = torch.operator "onnx.SequenceMap"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[3],si64>> {
^bb0(%arg1: !torch.vtensor<[?,?,?],f32>):
%1 = torch.operator "onnx.Shape"(%arg1) : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3],si64>
torch.operator_terminator %1 : !torch.vtensor<[3],si64>
}
return %0 : !torch.list<vtensor<[3],si64>>
}

// -----

// CHECK-LABEL: func.func @test_upsample_nearest
func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
Expand Down

0 comments on commit f0ce1e9

Please sign in to comment.