diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td index 6eeef1e0c0..e6dfe3fc66 100644 --- a/src/Transform/ONNX/Decompose.td +++ b/src/Transform/ONNX/Decompose.td @@ -49,8 +49,6 @@ def GetNullIntegerAttr : NativeCodeCall<"IntegerAttr()">; def GetNullStringAttr : NativeCodeCall<"StringAttr()">; -def GetNoneStringAttr : NativeCodeCall<"$_builder.getStringAttr(\"none\")">; - // Create a unit constant that will be used as none input. def CreateUnitConstant : NativeCodeCall<"::onnx_mlir::createUnitConstant($_builder, $_loc)">; @@ -347,7 +345,7 @@ def ClipV12Pattern : Pat< // Express Scatter (deprecated) using ScatterElements. def ScatterPattern : Pat< (ONNXScatterOp $data, $indices, $updates, $axis), - (ONNXScatterElementsOp $data, $indices, $updates, $axis, (GetNoneStringAttr)) + (ONNXScatterElementsOp $data, $indices, $updates, $axis, (GetNullStringAttr)) >; #endif // ONNX_DECOMPOSE diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 49199b41c2..52fd95d277 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -346,7 +346,7 @@ func.func @test_scatter(%arg0: tensor<64x25600xf32>, %arg1: tensor<64x100xi64>, // CHECK-LABEL: func @test_scatter // CHECK-SAME: ([[PARAM_0:%.+]]: tensor<64x25600xf32>, [[PARAM_1:%.+]]: tensor<64x100xi64>, [[PARAM_2:%.+]]: tensor<64x100xf32>) -> tensor<*xf32> { - // CHECK-NEXT: [[RES:%.+]] = "onnx.ScatterElements"(%arg0, %arg1, %arg2) {axis = 1 : si64, reduction = "none"} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32> + // CHECK-NEXT: [[RES:%.+]] = "onnx.ScatterElements"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32> // CHECK-NEXT: return [[RES]] : tensor<*xf32> }