Skip to content

Commit

Permalink
Add decompositions for v11 split, squeeze, and unsqueeze (#1702)
Browse files Browse the repository at this point in the history
* Add decompositions for v11 split, squeeze, and unsqueeze

Co-authored-by: Roberto DiCecco <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>

* Fix lit test

Signed-off-by: Philip Lassen <[email protected]>

* Add lit tests

Signed-off-by: Philip Lassen <[email protected]>

* Delete unneccesary decomp for unsqueeze

Signed-off-by: Philip Lassen <[email protected]>

Signed-off-by: Philip Lassen <[email protected]>
Co-authored-by: Roberto DiCecco <[email protected]>
  • Loading branch information
philass and rdicecco authored Sep 28, 2022
1 parent 9fcbbf3 commit 35b424c
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,12 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXScalerOp>();
target.addIllegalOp<ONNXScatterOp>();
target.addIllegalOp<ONNXSequenceConstructOp>();
target.addIllegalOp<ONNXSplitV11Op>();
target.addIllegalOp<ONNXSqueezeV11Op>();
target.addIllegalOp<ONNXUpsampleOp>();
target.addIllegalOp<ONNXUpsampleV9Op>();
target.addIllegalOp<ONNXUpsampleV7Op>();
target.addIllegalOp<ONNXUnsqueezeV11Op>();

RewritePatternSet patterns(context);
populateWithGenerated(patterns);
Expand Down
35 changes: 32 additions & 3 deletions src/Transform/ONNX/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def ONNXDataType : NativeCodeCall<
def ReduceL1OpPattern1
: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp(ONNXAbsOp $oprd),
(CreateUnitConstant), $keepdims, (noop_with_empty_axes)),
(CreateUnitConstant), $keepdims, (noop_with_empty_axes)),
[(AttributeIsNull:$axes)], (addBenefit 1)>;

def ReduceL1OpPattern2
Expand All @@ -119,7 +119,7 @@ def ReduceL2OpPattern
def ReduceLogSumOpPattern1
: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp(ONNXReduceSumOp $oprd,
(CreateUnitConstant), $keepdims, (noop_with_empty_axes))),
(CreateUnitConstant), $keepdims, (noop_with_empty_axes))),
[(AttributeIsNull:$axes)], (addBenefit 1)>;

def ReduceLogSumOpPattern2
Expand Down Expand Up @@ -160,7 +160,7 @@ def ReduceLogSumExpOpPattern2
def ReduceSumSquareOpPattern1
: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd),
(CreateUnitConstant), $keepdims, (noop_with_empty_axes)),
(CreateUnitConstant), $keepdims, (noop_with_empty_axes)),
[(AttributeIsNull:$axes)], (addBenefit 1)>;

def ReduceSumSquareOpPattern2
Expand Down Expand Up @@ -342,6 +342,35 @@ def ClipV12Pattern : Pat<
(ONNXClipOp $x, $min, $max)
>;

def SplitV11PatternNoAttr : Pat<
(ONNXSplitV11Op $x, $axis, $split),
(ONNXSplitOp $x, (CreateUnitConstant), $axis),
[(AttributeIsNull:$split)], (addBenefit 1)
>;

def SplitV11Pattern : Pat<
(ONNXSplitV11Op $x, $axis, $split),
(ONNXSplitOp $x, (ONNXConstantOpFromDenseAttr(createDenseArrayAttr $split)), $axis),
[], (addBenefit 0)
>;

def SqueezeV11PatternNoAttr : Pat<
(ONNXSqueezeV11Op $x, $axes),
(ONNXSqueezeOp $x, (CreateUnitConstant)),
[(AttributeIsNull:$axes)], (addBenefit 1)
>;

def SqueezeV11Pattern : Pat<
(ONNXSqueezeV11Op $x, $axes),
(ONNXSqueezeOp $x, (ONNXConstantOpFromDenseAttr(createDenseArrayAttr $axes))),
[], (addBenefit 0)
>;

def UnsqueezeV11Pattern : Pat<
(ONNXUnsqueezeV11Op $x, $axes),
(ONNXUnsqueezeOp $x, (ONNXConstantOpFromDenseAttr(createDenseArrayAttr $axes)))
>;

// Express Scatter (deprecated) using ScatterElements.
def ScatterPattern : Pat<
(ONNXScatterOp $data, $indices, $updates, $axis),
Expand Down
64 changes: 63 additions & 1 deletion test/mlir/onnx/onnx_decompose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func.func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: [[CST:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]], [[CST]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[SQUEEZE:%.+]] = "onnx.SqueezeV11"([[REDUCE_MAX]]) {axes = [1]} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[AXES:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK-NEXT: [[SQUEEZE:%.+]] = "onnx.Squeeze"([[REDUCE_MAX]], [[AXES]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
// CHECK-NEXT: [[RES:%.+]] = "onnx.Add"([[LOG]], [[SQUEEZE]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: return [[RES]] : tensor<*xf32>
}
Expand Down Expand Up @@ -340,6 +341,67 @@ func.func @test_clipv6(%arg0 : tensor<*xf32>) -> () {

// -----

func.func @test_splitV11(%arg0 : tensor<*xf32>) -> () {
%0 = "onnx.SplitV11"(%arg0) {axis = 1 : si64, split = [1]} : (tensor<*xf32>) -> tensor<*xf32>
return

// CHECK-LABEL: func @test_splitV11
// CHECK: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: [[VAR_1_:%.+]] = "onnx.Split"(%arg0, %0) {axis = 1 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
// CHECK: return
}

// -----

func.func @test_splitV11_no_split(%arg0 : tensor<*xf32>) -> () {
%0 = "onnx.SplitV11"(%arg0) {axis = 1 : si64} : (tensor<*xf32>) -> tensor<*xf32>
return

// CHECK-LABEL: func @test_splitV11_no_split
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[VAR_1_:%.+]] = "onnx.Split"(%arg0, %0) {axis = 1 : si64} : (tensor<*xf32>, none) -> tensor<*xf32>
// CHECK: return
}


// -----

func.func @test_squeezeV11(%arg0 : tensor<*xf32>) -> () {
%0 = "onnx.SqueezeV11"(%arg0) {axes = [1]} : (tensor<*xf32>) -> tensor<*xf32>
return

// CHECK-LABEL: func @test_squeezeV11
// CHECK: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: [[VAR_1_:%.+]] = "onnx.Squeeze"(%arg0, %0) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
// CHECK: return
}

// -----

func.func @test_squeezeV11_no_axes(%arg0 : tensor<*xf32>) -> () {
%0 = "onnx.SqueezeV11"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return

// CHECK-LABEL: func @test_squeezeV11_no_axes
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[VAR_1_:%.+]] = "onnx.Squeeze"(%arg0, %0) : (tensor<*xf32>, none) -> tensor<*xf32>
// CHECK: return
}

// -----

func.func @test_unsqueezeV11(%arg0 : tensor<*xf32>) -> () {
%0 = "onnx.UnsqueezeV11"(%arg0) {axes = [1]} : (tensor<*xf32>) -> tensor<*xf32>
return

// CHECK-LABEL: func @test_unsqueezeV11
// CHECK: [[VAR_0_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: [[VAR_1_:%.+]] = "onnx.Unsqueeze"(%arg0, %0) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
// CHECK: return
}

// -----

func.func @test_scatter(%arg0: tensor<64x25600xf32>, %arg1: tensor<64x100xi64>, %arg2: tensor<64x100xf32>) -> tensor<*xf32> {
%0 = "onnx.Scatter"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
Expand Down

0 comments on commit 35b424c

Please sign in to comment.