From c94f6e80f2b89c2d3febffbe1c786ddd1f764eb9 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Mon, 14 Nov 2022 13:45:28 -0800 Subject: [PATCH] Add parser for StableHlo_Dim --- stablehlo/dialect/StablehloAttrs.td | 15 +++++++++------ stablehlo/dialect/StablehloOps.cpp | 12 ++++++++++++ stablehlo/dialect/StablehloOps.h | 3 +++ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td index cef42b3b07b..c02ef51c266 100644 --- a/stablehlo/dialect/StablehloAttrs.td +++ b/stablehlo/dialect/StablehloAttrs.td @@ -20,7 +20,10 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/IR/TensorEncoding.td" -def StableHLO_Dim : ArrayRefParameter<"int64_t", "Dimension">; +def StableHLO_Dim : ArrayRefParameter<"int64_t", "Dimension"> { + let parser = "mlir::stablehlo::parseIntArray($_parser)"; + let printer = "mlir::stablehlo::printIntArray($_printer, $_self)"; +} def StableHLO_ScatterDimensionNumbers : AttrDef { let cppNamespace = "::mlir::stablehlo"; @@ -98,11 +101,11 @@ def OutputOperandAlias : AttrDef { "int64_t":$operandIndex, StableHLO_Dim:$operandTupleIndices ); - let assemblyFormat = "`<` " - "`output_tuple_indices` `=` `[` $outputTupleIndices `]` `,`" - "`operand_index` `=` $operandIndex `,`" - "`operand_tuple_indices` `=` `[` $operandTupleIndices `]`" - "`>`"; + let assemblyFormat = [{ + `<` `output_tuple_indices` `=` $outputTupleIndices `,` + `operand_index` `=` $operandIndex `,` + `operand_tuple_indices` `=` $operandTupleIndices `>` + }]; } def StableHLO_ArgResultAlias : AttrDef { diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 6370cd06516..19bc973d4f5 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -4509,6 +4509,18 @@ static ParseResult parseDimsWithMinimumElements(AsmParser& parser, return success(); } +FailureOr> parseIntArray(AsmParser& parser) { + SmallVector ints; + if (failed(parseDims(parser, ints))) return failure(); + return ints; +} + +void printIntArray(AsmPrinter& printer, ArrayRef ints) { + printer << '['; + llvm::interleaveComma(ints, printer); + printer << ']'; +} + /// Parse a custom attribute that resembles a struct of the form /// < /// foo = something_parsed_by_custom_parser, diff --git a/stablehlo/dialect/StablehloOps.h b/stablehlo/dialect/StablehloOps.h index bb382a58eec..23d4992a0ef 100644 --- a/stablehlo/dialect/StablehloOps.h +++ b/stablehlo/dialect/StablehloOps.h @@ -183,6 +183,9 @@ ParseResult parseWindowAttributes(OpAsmParser &parser, DenseIntElementsAttr &rhsDilation, DenseElementsAttr &windowReversal); +// Print and parse IntArrays +FailureOr> parseIntArray(AsmParser &parser); +void printIntArray(AsmPrinter &printer, ArrayRef ints); } // end namespace stablehlo } // end namespace mlir