From 2fac04613e5c5f9b638e45f50151ec7c4c3db954 Mon Sep 17 00:00:00 2001 From: Philip Lassen Date: Tue, 27 Sep 2022 16:33:30 -0700 Subject: [PATCH] Use NullStringAttr insead of "none" string Signed-off-by: Philip Lassen --- src/Transform/ONNX/Decompose.td | 4 +--- test/mlir/onnx/onnx_decompose.mlir | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td index 6eeef1e0c02..e6dfe3fc669 100644 --- a/src/Transform/ONNX/Decompose.td +++ b/src/Transform/ONNX/Decompose.td @@ -49,8 +49,6 @@ def GetNullIntegerAttr : NativeCodeCall<"IntegerAttr()">; def GetNullStringAttr : NativeCodeCall<"StringAttr()">; -def GetNoneStringAttr : NativeCodeCall<"$_builder.getStringAttr(\"none\")">; - // Create a unit constant that will be used as none input. def CreateUnitConstant : NativeCodeCall<"::onnx_mlir::createUnitConstant($_builder, $_loc)">; @@ -347,7 +345,7 @@ def ClipV12Pattern : Pat< // Express Scatter (deprecated) using ScatterElements. def ScatterPattern : Pat< (ONNXScatterOp $data, $indices, $updates, $axis), - (ONNXScatterElementsOp $data, $indices, $updates, $axis, (GetNoneStringAttr)) + (ONNXScatterElementsOp $data, $indices, $updates, $axis, (GetNullStringAttr)) >; #endif // ONNX_DECOMPOSE diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 49199b41c26..52fd95d2773 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -346,7 +346,7 @@ func.func @test_scatter(%arg0: tensor<64x25600xf32>, %arg1: tensor<64x100xi64>, // CHECK-LABEL: func @test_scatter // CHECK-SAME: ([[PARAM_0:%.+]]: tensor<64x25600xf32>, [[PARAM_1:%.+]]: tensor<64x100xi64>, [[PARAM_2:%.+]]: tensor<64x100xf32>) -> tensor<*xf32> { - // CHECK-NEXT: [[RES:%.+]] = "onnx.ScatterElements"(%arg0, %arg1, %arg2) {axis = 1 : si64, reduction = "none"} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32> + // CHECK-NEXT: [[RES:%.+]] = "onnx.ScatterElements"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<64x25600xf32>, tensor<64x100xi64>, tensor<64x100xf32>) -> tensor<*xf32> // CHECK-NEXT: return [[RES]] : tensor<*xf32> }