diff --git a/docs/ConstPropagationPass.md b/docs/ConstPropagationPass.md index 3d5669ed6f..e589be0c4e 100644 --- a/docs/ConstPropagationPass.md +++ b/docs/ConstPropagationPass.md @@ -27,16 +27,16 @@ func @foo() -> tensor<1xf32> { } ``` -## Remark +## Remark ONNXConstantOp uses MLIR DenseElementsAttr to store constant values. It is -important to note that, once a DenseElementsAttr is created, it is alive and +important to note that, once a DenseElementsAttr is created, it is alive and consumes memory until the end of compilation. In [Example](#example), all the three DenseElementsAttrs in the three ONNXConstantOps exist until the end of compilation. Especially, two intermediate DenseElementsAttrs in the two ONNXConstantOps produced by folding the two ONNXAddOps also exist. For a real world model, the number of intermediate DenseElementsAttrs will increase -quickly, which leads to a large memory footprint during compilation. +quickly, which leads to a large memory footprint during compilation. To avoid creating too many DenseElementsAttrs for intermediate ONNXConstantOps during `--constprop-onnx`, we designed a mechanism that dynamically allocates and @@ -93,7 +93,7 @@ def AddConstProp : Pat< // source patten: From add(lhs, rhs). (ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_), (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), - // result pattern: To c = lhs + rhs + // result pattern: To c = lhs + rhs (CreateAddOfTwoConst $addOp, $lhs, $rhs), // Additional constraints: if both lhs and rhs are dense constants. [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>; @@ -127,7 +127,7 @@ template Value ConstPropElementwiseBinary(PatternRewriter &rewriter, Value replacingValue, Value lhsValue, Value rhsValue) { ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue}); - Type replacingType = replacingValue.getType().cast(); + Type replacingType = mlir::cast(replacingValue.getType()); // Get lhs and rhs ElementsAttr from the values' defining constant ops. ElementsAttr lhs = getConstValueElements(lhsValue); diff --git a/docs/ImportONNXDefs.md b/docs/ImportONNXDefs.md index caf12beb8b..35f89c0e84 100644 --- a/docs/ImportONNXDefs.md +++ b/docs/ImportONNXDefs.md @@ -83,7 +83,7 @@ You will need to add the implementation code in the `src/Dialect/ONNX/ONNXOps.cp Tips: * Use `operandAdaptor` object to get the inputs (must use `operandAdaptor` to get the current values of the inputs) and the `op` object to get the attributes (can use `op` because attributes are typically immutable). * Use `hasShapeAndRank(X)` to test if `X` input is currently shaped and ranked. If not, return success as we will get a chance later to test the operation with this info. Note that some inputs may be scalar too, in which case they may or may not be encoded as a shape type. -* You can then use MLIR call `X.getType().cast()` to get a shape types, for which you can get the rank and the dimensions. At this time, we only check dimension validity for values known at runtime. Unknown dimensions are encoded as a negative number. Please only use the cast when you are sure that it will not assert, i.e. the type is indeed a `ShapedType`. +* You can then use MLIR call `mlir::cast(X.getType())` to get a shape types, for which you can get the rank and the dimensions. At this time, we only check dimension validity for values known at runtime. Unknown dimensions are encoded as a negative number. Please only use the cast when you are sure that it will not assert, i.e. the type is indeed a `ShapedType`. * When you find an error, report it with a friendly error message using `op->emitError(msg)`. ## Customize importer diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp index 3751391236..41b9467b11 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp @@ -51,14 +51,14 @@ bool isCompatibleWithNNPALevel(std::string inputNNPALevel) { /// zAIU supports only F16, F32 and BFLOAT. Since MLIR does not support BFLOAT, /// we check F16 and F32 here only. zAIU only supports rank in range of (0, 4]. bool isValidElementTypeAndRank(Value val, bool donotCheckRank) { - if (val.getType().isa()) + if (mlir::isa(val.getType())) return true; - if (auto valueType = val.getType().dyn_cast_or_null()) { + if (auto valueType = mlir::dyn_cast_or_null(val.getType())) { Type elementType = (valueType) ? valueType.getElementType() : val.getType(); // Element type must be in 16 or F32. - if (elementType.isa() && - (elementType.cast().getWidth() == 16 || - elementType.cast().getWidth() == 32)) { + if (mlir::isa(elementType) && + (mlir::cast(elementType).getWidth() == 16 || + mlir::cast(elementType).getWidth() == 32)) { if (donotCheckRank) return true; // Rank must be in range of (0, 4]. @@ -78,8 +78,8 @@ bool checkLegalityPoolOpsCommon(POOLOP op, Value Y) { shapeHelper.computeShapeAndAssertOnFailure(); Value X = op.getX(); int64_t ceilMode = op.getCeilMode(); - ShapedType inputType = X.getType().cast(); - ShapedType outputType = Y.getType().cast(); + ShapedType inputType = mlir::cast(X.getType()); + ShapedType outputType = mlir::cast(Y.getType()); ArrayRef shapeInput = inputType.getShape(); ArrayRef shapeOutput = outputType.getShape(); @@ -248,14 +248,14 @@ bool meetPoolParamRestrictions(int64_t inputShape, int64_t kernelShape, if (outputShape != 1) return false; // padding_type must be VALID_PADDING. - if (!paddingType.equals("VALID_PADDING")) + if (!(paddingType == "VALID_PADDING")) return false; } else { // strides are greater than zero // kernel_width and kernel_height must be less than or equal to 64. if (kernelShape > 64) return false; - if (paddingType.equals("SAME_PADDING")) { + if (paddingType == "SAME_PADDING") { if (outputShape != ceil((float)inputShape / strides)) return false; } else { // VALID_PADDING @@ -412,7 +412,7 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getInput())) return false; - ShapedType inputType = op.getType().cast(); + ShapedType inputType = mlir::cast(op.getType()); if (!inputType.hasRank()) return false; int64_t rank = inputType.getRank(); @@ -429,7 +429,7 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getX())) return false; - ShapedType xType = op.getX().getType().cast(); + ShapedType xType = mlir::cast(op.getX().getType()); return xType.hasRank() && (xType.getRank() <= 4); } @@ -442,7 +442,7 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getInput())) return false; - ShapedType inputType = op.getType().cast(); + ShapedType inputType = mlir::cast(op.getType()); return inputType.hasRank() && (inputType.getRank() <= 4); } @@ -455,7 +455,7 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getX())) return false; - ShapedType xType = op.getX().getType().cast(); + ShapedType xType = mlir::cast(op.getX().getType()); return xType.hasRank() && (xType.getRank() <= 4); } @@ -468,7 +468,7 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getInput())) return false; - ShapedType inputType = op.getInput().getType().cast(); + ShapedType inputType = mlir::cast(op.getInput().getType()); return inputType.hasRank() && (inputType.getRank() <= 4); } @@ -481,7 +481,7 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getInput())) return false; - ShapedType inputType = op.getInput().getType().cast(); + ShapedType inputType = mlir::cast(op.getInput().getType()); return inputType.hasRank() && (inputType.getRank() <= 4); } @@ -500,8 +500,8 @@ bool isSuitableForZDNN( return false; if (!isValidElementTypeAndRank(op.getOperand(1))) return false; - ShapedType aType = op.getOperand(0).getType().cast(); - ShapedType bType = op.getOperand(1).getType().cast(); + ShapedType aType = mlir::cast(op.getOperand(0).getType()); + ShapedType bType = mlir::cast(op.getOperand(1).getType()); // Illegal if A or B is unranked. if (!aType.hasRank() || !bType.hasRank()) @@ -558,8 +558,8 @@ bool isSuitableForZDNN( if (!isValidElementTypeAndRank(C)) return false; - ShapedType aType = A.getType().cast(); - ShapedType bType = B.getType().cast(); + ShapedType aType = mlir::cast(A.getType()); + ShapedType bType = mlir::cast(B.getType()); ShapedType cType; ArrayRef aShape = aType.getShape(); ArrayRef bShape = bType.getShape(); @@ -567,7 +567,7 @@ bool isSuitableForZDNN( bool hasC = !isNoneValue(C); if (hasC) { - cType = C.getType().cast(); + cType = mlir::cast(C.getType()); cShape = cType.getShape(); } @@ -612,7 +612,7 @@ bool isSuitableForZDNN( std::optional axes = op.getAxes(); int64_t keepdims = op.getKeepdims(); - ShapedType dataType = op.getData().getType().cast(); + ShapedType dataType = mlir::cast(op.getData().getType()); auto shapeData = dataType.getShape(); // Check keepdims. @@ -623,8 +623,8 @@ bool isSuitableForZDNN( mlir::ArrayAttr axesVal = axes.value(); SmallVector axesAttrs(axesVal.begin(), axesVal.end()); if ((axesAttrs.size() != 2) || - (axesAttrs[0].dyn_cast().getInt() != 2) || - (axesAttrs[1].dyn_cast().getInt() != 3)) { + (mlir::dyn_cast(axesAttrs[0]).getInt() != 2) || + (mlir::dyn_cast(axesAttrs[1]).getInt() != 3)) { return false; } @@ -676,15 +676,15 @@ bool isSuitableForZDNN( if (!isValidElementTypeAndRank(B)) return false; - int64_t hidden_size = R.getType().cast().getShape()[2]; + int64_t hidden_size = mlir::cast(R.getType()).getShape()[2]; std::optional activations = op.getActivations(); // Check if direction and hidden_size in W have static dimensions. - ArrayRef wShape = W.getType().cast().getShape(); + ArrayRef wShape = mlir::cast(W.getType()).getShape(); if ((wShape[0] != 1 && wShape[0] != 2) || wShape[1] == ShapedType::kDynamic) return false; // Check if R has static dimensions, and the direction dim is 1 or 2. - ArrayRef rShape = R.getType().cast().getShape(); - if (!R.getType().cast().hasStaticShape() || + ArrayRef rShape = mlir::cast(R.getType()).getShape(); + if (!mlir::cast(R.getType()).hasStaticShape() || (rShape[0] != 1 && rShape[0] != 2)) return false; // Check hidden_size. @@ -694,11 +694,11 @@ bool isSuitableForZDNN( if (!isNoneValue(op.getSequenceLens())) return false; // check if B, initial_h and initial_c have static dimensions if given. - if (!isNoneValue(B) && !B.getType().cast().hasStaticShape()) + if (!isNoneValue(B) && !mlir::cast(B.getType()).hasStaticShape()) return false; // check if B's direction dim is 1 or 2. if (!isNoneValue(B)) { - ArrayRef bShape = B.getType().cast().getShape(); + ArrayRef bShape = mlir::cast(B.getType()).getShape(); if (bShape[0] != 1 && bShape[0] != 2) return false; } @@ -708,12 +708,14 @@ bool isSuitableForZDNN( return false; // zDNN support the default activations (["Sigmoid", "Tanh", "Tanh"]) only. if ((activations && (activations.value().size() > 0) && - (activations.value()[0].cast().getValue() != + (mlir::cast(activations.value()[0]).getValue() != "Sigmoid")) || (activations && (activations.value().size() > 1) && - (activations.value()[1].cast().getValue() != "Tanh")) || + (mlir::cast(activations.value()[1]).getValue() != + "Tanh")) || (activations && (activations.value().size() > 2) && - (activations.value()[2].cast().getValue() != "Tanh"))) + (mlir::cast(activations.value()[2]).getValue() != + "Tanh"))) return false; // zDNN does not support clip(Cell clip threshold). if (op.getClip()) @@ -755,24 +757,24 @@ bool isSuitableForZDNN( if (!isValidElementTypeAndRank(B)) return false; - int64_t hidden_size = R.getType().cast().getShape()[2]; + int64_t hidden_size = mlir::cast(R.getType()).getShape()[2]; std::optional activations = op.getActivations(); // Check if direction and hidden_size in W have static dimensions. - ArrayRef wShape = W.getType().cast().getShape(); + ArrayRef wShape = mlir::cast(W.getType()).getShape(); if ((wShape[0] != 1 && wShape[0] != 2) || wShape[1] == ShapedType::kDynamic) return false; // Check if R has static dimensions. - if (!R.getType().cast().hasStaticShape()) + if (!mlir::cast(R.getType()).hasStaticShape()) return false; // Check hidden_size. if (hidden_size > MAXIMUM_NUM_HIDDEN_SIZE_GRU) return false; // check if B and initial_h have static dimensions if given. - if (!isNoneValue(B) && !B.getType().cast().hasStaticShape()) + if (!isNoneValue(B) && !mlir::cast(B.getType()).hasStaticShape()) return false; // check if B's direction dim is 1 or 2. if (!isNoneValue(B)) { - ArrayRef bShape = B.getType().cast().getShape(); + ArrayRef bShape = mlir::cast(B.getType()).getShape(); if (bShape[0] != 1 && bShape[0] != 2) return false; } @@ -781,12 +783,14 @@ bool isSuitableForZDNN( return false; // zDNN support the default activations (["Sigmoid", "Tanh", "Tanh"]) only. if ((activations && (activations.value().size() > 0) && - (activations.value()[0].cast().getValue() != + (mlir::cast(activations.value()[0]).getValue() != "Sigmoid")) || (activations && (activations.value().size() > 1) && - (activations.value()[1].cast().getValue() != "Tanh")) || + (mlir::cast(activations.value()[1]).getValue() != + "Tanh")) || (activations && (activations.value().size() > 2) && - (activations.value()[2].cast().getValue() != "Tanh"))) + (mlir::cast(activations.value()[2]).getValue() != + "Tanh"))) return false; // zDNN does not support clip(Cell clip threshold). if (op.getClip()) @@ -859,7 +863,7 @@ static bool checkConv2DParamRestrictions(int64_t inputDim, int64_t kernelDim, int64_t stride, int64_t outputDim, StringRef paddingType) { if (stride == 0) { // paddingType must be VALID_PADDING. - if (!paddingType.equals("VALID_PADDING")) + if (!(paddingType == "VALID_PADDING")) return false; // inputDim must be = kernel dim. if (inputDim != kernelDim) @@ -875,7 +879,7 @@ static bool checkConv2DParamRestrictions(int64_t inputDim, int64_t kernelDim, // kernel dim must be less than or equal to 64. if (kernelDim > 64) return false; - if (paddingType.equals("SAME_PADDING")) { + if (paddingType == "SAME_PADDING") { // height_out restriction. if (outputDim != ceil((float)inputDim / stride)) return false; @@ -913,8 +917,8 @@ bool isSuitableForZDNN( ONNXConvOpShapeHelper shapeHelper(op.getOperation(), {}); shapeHelper.computeShapeAndAssertOnFailure(); - ShapedType inputType = op.getX().getType().cast(); - ShapedType outputType = op.getY().getType().cast(); + ShapedType inputType = mlir::cast(op.getX().getType()); + ShapedType outputType = mlir::cast(op.getY().getType()); ArrayRef shapeInput = inputType.getShape(); ArrayRef shapeOutput = outputType.getShape(); @@ -978,8 +982,8 @@ bool isSuitableForZDNN( template <> bool isSuitableForZDNN( ONNXBatchNormalizationInferenceModeOp op, const DimAnalysis *dimAnalysis) { - ShapedType inputType = op.getX().getType().cast(); - ShapedType outputType = op.getO_Y().getType().cast(); + ShapedType inputType = mlir::cast(op.getX().getType()); + ShapedType outputType = mlir::cast(op.getO_Y().getType()); ArrayRef shapeInput = inputType.getShape(); ArrayRef shapeOutput = outputType.getShape(); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp index 59d0e997e8..ad1dfda44d 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp @@ -46,7 +46,7 @@ ArrayAttr getLSTMGRUBiasSplitShape( Value getLSTMGRUZDNNWeightFromONNXWeight( Location loc, PatternRewriter &rewriter, Value weight, int isLSTM) { int64_t splitNum = isLSTM ? 4 : 3; - RankedTensorType weightType = weight.getType().cast(); + RankedTensorType weightType = mlir::cast(weight.getType()); Type elementType = weightType.getElementType(); ArrayRef weightShape = weightType.getShape(); int64_t direction = weightShape[0]; @@ -124,7 +124,7 @@ Value getLSTMGRUGetYh(Location loc, PatternRewriter &rewriter, Value val, if (isNoneValue(resYh) || isNoneValue(val)) return noneValue; - ArrayRef shapeX = X.getType().cast().getShape(); + ArrayRef shapeX = mlir::cast(X.getType()).getShape(); MultiDialectBuilder create(rewriter, loc); // Generate Y_h for onnx.LSTM from hn_output for all timestep Value minusOne = create.onnx.constantInt64({-1}); @@ -136,12 +136,12 @@ Value getLSTMGRUGetYh(Location loc, PatternRewriter &rewriter, Value val, Value intMax = create.onnx.constantInt64({INT_MAX}); StringRef directionStr = direction.getValue(); ArrayRef resYhShape = - resYh.getType().cast().getShape(); + mlir::cast(resYh.getType()).getShape(); int64_t T = isNoneValue(resY) ? 1 : shapeX[0]; int64_t D = resYhShape[0]; int64_t B = resYhShape[1]; int64_t H = resYhShape[2]; - Type elementType = resYh.getType().cast().getElementType(); + Type elementType = mlir::cast(resYh.getType()).getElementType(); Value axis = zero; Value step = one; Value ret; @@ -205,19 +205,20 @@ Value getLSTMGRUGetYc( SmallVector emitONNXSplitOp(Location loc, PatternRewriter &rewriter, Value input, IntegerAttr axis, ArrayAttr split) { - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); SmallVector outputTypes; int64_t splitNum = split.size(); ArrayRef inputShape = - input.getType().cast().getShape(); - int64_t splitAxis = axis.cast().getSInt(); + mlir::cast(input.getType()).getShape(); + int64_t splitAxis = mlir::cast(axis).getSInt(); assert(splitAxis >= 0 && "Negative axis"); for (int i = 0; i < splitNum; i++) { SmallVector outputShape; for (size_t dim = 0; dim < inputShape.size(); dim++) { - outputShape.emplace_back((dim == (unsigned int)splitAxis) - ? split[dim].cast().getInt() - : inputShape[dim]); + outputShape.emplace_back( + (dim == (unsigned int)splitAxis) + ? mlir::cast(split[dim]).getInt() + : inputShape[dim]); } outputTypes.emplace_back(RankedTensorType::get(outputShape, elementType)); } diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td index 7dd088bb3c..95d2756ea7 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td @@ -31,21 +31,21 @@ include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td" def IsEnableScalarBcastBinary: Constraint>; -def IsNoneType : Constraint())">>; +def IsNoneType : Constraint(($_self).getType())">>; -def IsNotNoneType : Constraint())">>; +def IsNotNoneType : Constraint($_self)">>; class HasRankOf : Constraint< - CPred<"$0.getType().isa() && " - "$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # rank> + CPred<"mlir::isa($0.getType()) && " + "mlir::cast($0.getType()).hasRank() && " + "mlir::cast($0.getType()).getRank() == " # rank> >; def IsBiasNoneOr1D : Constraint< - CPred<"$_self.getType().isa() || " - " ($_self.getType().isa() && " - " $_self.getType().cast().hasRank() && " - " $_self.getType().cast().getRank() == 1)"> + CPred<"mlir::isa($_self.getType()) || " + "mlir::isa($_self.getType()) && " + " mlir::cast($_self.getType()).hasRank() && " + " mlir::cast($_self.getType()).getRank() == 1)"> >; class VariadicSizeIs : Constraint< @@ -66,16 +66,16 @@ class GetNthVariadicOperand : NativeCodeCall<"$0[" # n # "]">; class getNthVariadicResults : NativeCodeCall; def GetShape : - NativeCodeCall<"$0.getType().cast().getShape()">; + NativeCodeCall<"mlir::cast($0.getType()).getShape()">; def GetDynShape: NativeCodeCall<"getDynShape($_loc, $_builder, $0)">; def GetRank : - NativeCodeCall<"$0.getType().cast().getRank()">; + NativeCodeCall<"mlir::cast($0.getType()).getRank()">; def GetSInt : - NativeCodeCall<"$0.cast().getSInt()">; + NativeCodeCall<"mlir::cast($0).getSInt()">; class GetStrAttr : NativeCodeCall<"$_builder.getStringAttr(\"" # s # "\")">; @@ -93,7 +93,7 @@ class GetI64ArrayAttr : NativeCodeCall<"$_builder.getI64ArrayAttr(" # n # ")">; def GetUnrankedTensorTypeOf : NativeCodeCall< - "UnrankedTensorType::get($0.getType().cast().getElementType())" + "UnrankedTensorType::get(mlir::cast($0.getType()).getElementType())" >; class EmitOp3 : @@ -135,7 +135,7 @@ def replaceONNXSigmoidPattern : Pat< (ONNXSigmoidOp $x), (ZHighUnstickOp (ZHighSigmoidOp (ZHighStickOp:$s_x $x, (NoneLayoutAttr)), (returnType $s_x))) ->; +>; //===----------------------------------------------------------------------===// // ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X), @@ -158,7 +158,7 @@ def replaceONNXAddPattern : Pat< def GetONNXSumOpWithoutFirst : NativeCodeCall< "$_builder.create(" " $_loc," - " $0[0].getType().cast()," + " mlir::cast($0[0].getType())," " ValueRange(OperandRange{$0.begin() + 1, $0.end()}))" >; @@ -290,7 +290,7 @@ def replaceONNXMaxPattern : Pat< (NoneLayoutAttr)), (returnType $s_x))) >; - + //===----------------------------------------------------------------------===// // ONNXSoftmaxOp %X = ONNXSqueezeOp // (ZHighUnstickOp @@ -343,7 +343,7 @@ def replaceONNXSoftmax3DPattern : Pat< // 0) //===----------------------------------------------------------------------===// def IsSoftmaxLegalForZDNN: Constraint< - CPred<"isSuitableForZDNN(" # + CPred<"isSuitableForZDNN(" # "dyn_cast_or_null($0.getDefiningOp()))">, "Softmax is legal for zDNN" >; @@ -381,7 +381,7 @@ def replaceONNXReduceMeanV13Pattern : Pat< >; //===----------------------------------------------------------------------===// -// ONNXMaxPoolSingleOutOp %X = +// ONNXMaxPoolSingleOutOp %X = // (ZHighUnstickOp // (ZHighMaxPoolOp // (ZHighStickOp %X), @@ -423,7 +423,7 @@ def replaceONNXMaxPoolSingleOutPattern : Pattern< >; //===----------------------------------------------------------------------===// -// ONNXAveragePoolOp %X = +// ONNXAveragePoolOp %X = // (ZHighUnstickOp // (ZHighAvgPoolOp // (ZHighStickOp %X), @@ -503,22 +503,22 @@ def replaceONNXMatMulPattern : Pat< //===----------------------------------------------------------------------===// def IsMatMulLegalForZDNN: Constraint< - CPred<"isSuitableForZDNN(" # + CPred<"isSuitableForZDNN(" # "dyn_cast_or_null($0.getDefiningOp()))">, "MatMul is legal for zDNN" >; // Be careful, this check is very specific to '$0' of rank 2 and '$1' of rank 1. def HaveSameLastDimR2R1: Constraint< - CPred<"(!$0.getType().cast().isDynamicDim(1))" # - " && (!$1.getType().cast().isDynamicDim(0))" # - " && ($0.getType().cast().getShape()[1]" # - " == $1.getType().cast().getShape()[0])">, + CPred<"(!mlir::cast($0.getType()).isDynamicDim(1))" # + " && (!mlir::cast($1.getType()).isDynamicDim(0))" # + " && (mlir::cast($0.getType()).getShape()[1]" # + " == mlir::cast($1.getType()).getShape()[0])">, "Have the same last dimension" >; -// Only 1D bias is suitable for this transformation since only then -// the semantics of bias addition is the same for both ONNX and zDNN. +// Only 1D bias is suitable for this transformation since only then +// the semantics of bias addition is the same for both ONNX and zDNN. def replaceONNXMatMulAddPattern1 : Pat< // From Add $b, (MatMul $x, $y) (ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)), @@ -552,14 +552,14 @@ def replaceONNXMatMulAddPattern2 : Pat< //===----------------------------------------------------------------------===// // GEMM //===----------------------------------------------------------------------===// -def IsTransposed: Constraint().getSInt() == 1)">>; +def IsTransposed: Constraint($_self).getSInt() == 1)">>; def Transpose2D: NativeCodeCall< "emitONNXTranspose($_loc, $_builder, $0, SmallVector({1, 0}))">; //===----------------------------------------------------------------------===// // NNPA does not directly support cases with alpha != 1.0 or beta != 1.0 or bias's rank == 2. -// Since alpha and beta were legalized by the pass configuration, we check rank only. +// Since alpha and beta were legalized by the pass configuration, we check rank only. // // If bias's rank is 1, directly lower ONNX GEMM to ZHigh MatMul. // ONNXGemmOp %A %B %C = ZHighUnstickOp @@ -672,7 +672,7 @@ def replaceONNXLSTMPattern1 : Pattern< [ // // Pre-process to prepare input_bias and hidden_bias. - // + // (EmitONNXSplitOp8Results:$splitB $b, (GetI64NAttr<1>), (GetLSTMGRUBiasSplitShape (GetShape $r))), (ZHighStickForLSTMOp:$input_bias $splitB__2, $splitB__0, $splitB__3, $splitB__1), (ZHighStickForLSTMOp:$hidden_bias $splitB__6, $splitB__4, $splitB__7, $splitB__5), @@ -828,12 +828,12 @@ def replaceONNXGRUPattern1 : Pattern< [ // // Pre-process to prepare input_bias and hidden_bias. - // + // (EmitONNXSplitOp6Results:$splitB $b, (GetI64NAttr<1>), (GetLSTMGRUBiasSplitShape (GetShape $r))), (ZHighStickForGRUOp:$input_bias $splitB__0, $splitB__1, $splitB__2), (ZHighStickForGRUOp:$hidden_bias $splitB__3, $splitB__4, $splitB__5), - + // // Main-process to lower ONNXGRUOp into ZHighGRUOp. // @@ -847,7 +847,7 @@ def replaceONNXGRUPattern1 : Pattern< $hidden_size, // hidden_size $direction, // direction (GetLSTMGRUReturnAllStepsAttr $res__0, $res__1)), // return_all_steps - + // // Post-process to generate two return values of ONNXGRUOp // from two return values of ZHighGRUOp(=$zHighGRU). @@ -917,7 +917,7 @@ def replaceONNXGRUPattern3 : Pattern< $hidden_size, // hidden_size $direction, // direction (GetLSTMGRUReturnAllStepsAttr $res__0, $res__1)), // return_all_steps - + // // Post-process to generate two return values of ONNXGRUOp // from two return values of ZHighGRUOp(=$zHighGRU). @@ -964,7 +964,7 @@ def replaceONNXGRUPattern4 : Pattern< // Rewrite // // ONNXConvOp %X, %W, %B -// +// // to // // (ZHighUnstickOp @@ -976,7 +976,7 @@ def replaceONNXGRUPattern4 : Pattern< // strides, // GetPaddingType, // ACT_NONE))) -// +// //===----------------------------------------------------------------------===// def GetStrAttrPaddingtypeConv: NativeCodeCall< @@ -998,7 +998,7 @@ def replaceONNXConv2DPattern : Pattern< (GetStrAttrPaddingtypeConv:$padtype $res), (GetI64ArrayAttrKernelShapeConv:$kernel_shape $res), (GetI64ArrayAttrStridesConv:$strides $res), - + (ZHighUnstickOp (ZHighConv2DOp (ZHighStickOp $x, (NHWCLayoutAttr)), @@ -1043,7 +1043,7 @@ def replaceONNXReluConvPattern : Pattern< (GetStrAttrPaddingtypeConv:$padtype $res), (GetI64ArrayAttrKernelShapeConv:$kernel_shape $res), (GetI64ArrayAttrStridesConv:$strides $res), - + (ZHighUnstickOp (ZHighConv2DOp (ZHighStickOp $x, (NHWCLayoutAttr)), diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp index 2eef0b9646..fcb1a6b410 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp @@ -60,7 +60,7 @@ void addDynamicallyLegalOpFor(mlir::ConversionTarget *target, bool exceedLimit = llvm::any_of(genericOp->getOperands(), [](mlir::Value operand) { if (auto valueType = - operand.getType().dyn_cast()) { + mlir::dyn_cast(operand.getType())) { // Check if static dimension size exceeds zDNN limitations llvm::ArrayRef valueShape = valueType.getShape(); if (llvm::any_of(valueShape, [](int64_t dim) { diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp index 61157e1a96..d0acc5e2dd 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp @@ -61,7 +61,7 @@ inline int64_t summarizeHigherDims( void processDim(Value oper, int64_t &e4, int64_t &e3, int64_t &e2, int64_t &e1, std::string &msg) { // At this time, use only 1 of the two operands. - ShapedType operType = oper.getType().dyn_cast_or_null(); + ShapedType operType = mlir::dyn_cast_or_null(oper.getType()); assert(operType && operType.hasRank() && "expected shaped type with rank"); int64_t operRank = operType.getRank(); assert(operRank <= 4 && "expected rank <= 4"); @@ -117,7 +117,7 @@ void estimateTimeForMatMulOp(Operation *op, Value a, Value b, bool aTransposed, bool bTransposed, const DimAnalysis *dimAnalysis, double &cpuEstimatedTime, double &nnpaEstimatedTime) { // Scanning A. - ShapedType aType = a.getType().dyn_cast_or_null(); + ShapedType aType = mlir::dyn_cast_or_null(a.getType()); assert(aType && aType.hasRank() && "expected shaped type with A rank"); int64_t aRank = aType.getRank(); llvm::ArrayRef aShape = aType.getShape(); @@ -128,7 +128,7 @@ void estimateTimeForMatMulOp(Operation *op, Value a, Value b, bool aTransposed, int64_t aN = aShape[aNIndex]; int64_t aM = aShape[aMIndex]; // Scanning B. - ShapedType bType = b.getType().dyn_cast_or_null(); + ShapedType bType = mlir::dyn_cast_or_null(b.getType()); assert(bType && bType.hasRank() && "expected shaped type with B rank"); int64_t bRank = bType.getRank(); llvm::ArrayRef bShape = bType.getShape(); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp index 943c577181..9d751d246f 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp @@ -47,7 +47,7 @@ namespace onnx_mlir { /// A = scale / sqrt(var + epsilon) Value getSqrtResultBatchNormA( Location loc, PatternRewriter &rewriter, Value var, FloatAttr epsilon) { - Type elementType = var.getType().cast().getElementType(); + Type elementType = mlir::cast(var.getType()).getElementType(); MultiDialectBuilder create(rewriter, loc); // epsilon @@ -195,7 +195,8 @@ bool isDefinedByONNXConstantOp(Value v) { bool canInferencePadsForNNPAConv(ONNXConvOp op) { ONNXConvOpShapeHelper shapeHelper(op.getOperation(), {}); shapeHelper.computeShapeAndAssertOnFailure(); - RankedTensorType inputType = op.getX().getType().cast(); + RankedTensorType inputType = + mlir::cast(op.getX().getType()); ArrayRef inputShape = inputType.getShape(); // dimension of inferenced pads should be 4D if (shapeHelper.pads.size() != 4) @@ -242,9 +243,10 @@ DenseElementsAttr insertZerosForNonPaddedDims( int nElements = (nDims + extensionLength) * 2; SmallVector pads(nElements, 0); for (int i = 0; i < nDims; ++i) { - int64_t beginPad = origAttrs.getValue()[i].cast().getInt(); + int64_t beginPad = + mlir::cast(origAttrs.getValue()[i]).getInt(); int64_t endPad = - origAttrs.getValue()[nDims + i].cast().getInt(); + mlir::cast(origAttrs.getValue()[nDims + i]).getInt(); pads[i + extensionLength] = beginPad; pads[nDims + extensionLength + i + extensionLength] = endPad; } @@ -253,7 +255,8 @@ DenseElementsAttr insertZerosForNonPaddedDims( DenseElementsAttr createDenseFloatAttrOfValue( PatternRewriter &rewriter, Value origValue, float constantValue) { - Type elementType = origValue.getType().cast().getElementType(); + Type elementType = + mlir::cast(origValue.getType()).getElementType(); SmallVector wrapper(1, 0); wrapper[0] = constantValue; return DenseElementsAttr::get( @@ -271,13 +274,13 @@ ArrayAttr createArrayAttrOfZeros( // Create Type for Padded input Type CreatePaddedXType(Value x, ArrayAttr pads) { - RankedTensorType inputType = x.getType().cast(); + RankedTensorType inputType = mlir::cast(x.getType()); ArrayRef inputShape = inputType.getShape(); Type elementType = inputType.getElementType(); SmallVector paddingShape(4, 0); if (pads) { for (int i = 0; i < 4; i++) { - paddingShape[i] = pads.getValue()[i].cast().getInt(); + paddingShape[i] = mlir::cast(pads.getValue()[i]).getInt(); } } SmallVector paddedShape = {inputShape[0], inputShape[1], @@ -305,7 +308,7 @@ Type CreatePaddedXType(Value x, ArrayAttr pads) { /// ^ | A2 | | | | | | /// N-MDIS | | | v | | | | /// v +------------------------+ +-----------+-----------+-----+ -/// +/// /// Then, /// - for A1, do (A1 * B1), (A1 * B2), (A1 * B3), and concat the results to get (A1*B) /// - for A2, do (A2 * B1), (A2 * B2), (A2 * B3), and concat the results to get (A2*B) @@ -652,7 +655,8 @@ void getRewriteONNXForZHighDynamicallyLegal( if (!isCompatibleWithNNPALevel(NNPA_Z16)) return true; Value input = op.getInput(); - if (auto shapedType = input.getType().dyn_cast()) { + if (auto shapedType = + mlir::dyn_cast(input.getType())) { // Check element type. if (!isValidElementTypeAndRank(input, true)) return true; diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td index 4e0570fa3d..b44e7098c4 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td @@ -79,14 +79,14 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern< // Rewrite `BinaryOp(lhs, rhs)` if one of the two inputs is a constant and // unidirectional broadcastable to the other input. // For example: lhs is a constant of shape [8] and rhs is of shape [1x4x8]. -// +// // Taking ONNXAddOp as an example, we rewrite it as follows: // // 1. `ONNXAddOp %constant, %X` will be canonicalized to `ONNXAddOp %X, %constant` // 2. `ONNXAddOp %X, %constant` will be replaced by // `ONNXAddOp %X, (ONNXExpandOp %constant, (ONNXShapeOp %X))` // -// +// //===----------------------------------------------------------------------===// @@ -97,7 +97,7 @@ def CreateShapeOp: NativeCodeCall< // Get a type for a tensor that stores the shape of another tensor. def GetShapeTypeOf: NativeCodeCall< - "RankedTensorType::get({$0.getType().cast().getRank()}, $_builder.getIntegerType(64))" + "RankedTensorType::get({mlir::cast($0.getType()).getRank()}, $_builder.getIntegerType(64))" >; // Check unidirectional broadcasting from the first to second tensor. @@ -286,7 +286,7 @@ class FloatAttrOfValue: // Check that a StrAttr does not contain a specific value. class IsNotStringAttrOfValue: - Constraint().getValue() != \"" # val # "\"">>; + Constraint($0).getValue() != \"" # val # "\"">>; // Check the a convolution operation is NOT leagal for zDNN def IsConvNotLegalForZDNN: Constraint< diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index 98e7235f1d..ebf7daff0b 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -187,7 +187,7 @@ Value insertAllocOrEmitZeroConstant(ArrayRef dims, ZTensorEncodingAttr::get(op->getContext(), layout)); ZMemRefType zMemRefType = convertZTensorToMemRefType(tensorType); MemRefType resType = - affine::normalizeMemRefType(zMemRefType.value.cast()); + affine::normalizeMemRefType(mlir::cast(zMemRefType.value)); // Create a ZHighStickifiedConstantOp. ZHighStickifiedConstantOp stickifiedConstant = @@ -240,9 +240,9 @@ Value insertShapeMemRefI64( /// Get the corresponding MemRefType and layout of a given ZTensorType. ZMemRefType convertZTensorToMemRefType(Type type) { ZMemRefType resZMemRefType; - if (type.isa()) { + if (mlir::isa(type)) { OpBuilder b(type.getContext()); - RankedTensorType tensorType = type.dyn_cast(); + RankedTensorType tensorType = mlir::dyn_cast(type); assert(tensorType && "expected only ranked shapes"); ArrayRef shape = tensorType.getShape(); Type elementType = tensorType.getElementType(); @@ -480,8 +480,8 @@ ZMemRefType convertZTensorToMemRefType(Type type) { } else { // Does not have tensorType.getEncoding(). resZMemRefType.value = MemRefType::get(shape, elementType); } - } else { // Not type.isa(). - resZMemRefType.value = type.dyn_cast(); + } else { // Not mlir::isa(type). + resZMemRefType.value = mlir::dyn_cast(type); } return resZMemRefType; } @@ -668,16 +668,15 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern { convertZTensorToMemRefType(*op->result_type_begin()); // Normalize MemRefType to get a static shape. - assert(zMemRefType.value.cast().getNumDynamicDims() == 0 && + assert(mlir::cast(zMemRefType.value).getNumDynamicDims() == 0 && "MemRefType has dynamic dimensions"); MemRefType normalizedType = - affine::normalizeMemRefType(zMemRefType.value.cast()); + affine::normalizeMemRefType(mlir::cast(zMemRefType.value)); ArrayRef normalizedShape = normalizedType.getShape(); // Get dense resource attribute. - auto blob = stickifiedConstOp.getValue() - .value() - .cast() + auto blob = mlir::cast( + stickifiedConstOp.getValue().value()) .getRawHandle() .getBlob(); assert(blob && "Expecting dense resource with a valid blob"); @@ -1030,7 +1029,7 @@ struct ZHighToZLowMatMulOpLowering : public ConversionPattern { // Prepare optional bias. Value bias = operandAdaptor.getB(); - if (bias.getType().isa()) { + if (mlir::isa(bias.getType())) { SmallVector resDims, biasDims; create.krnlIE.getShapeAsDims(alloc, resDims); ZTensorEncodingAttr::DataLayout biasLayout; @@ -1115,19 +1114,19 @@ struct ZHighToZLowLSTMOpLowering : public ConversionPattern { Value initial_c = operandAdaptor.getC0(); Value input_bias = operandAdaptor.getInputBias(); Value hidden_bias = operandAdaptor.getHiddenBias(); - if (initial_h.getType().isa()) { + if (mlir::isa(initial_h.getType())) { initial_h = insertAllocOrEmitZeroConstant(shapeHelper.hc0Shape, ZTensorEncodingAttr::DataLayout::_3DS, op, rewriter, loc); } - if (initial_c.getType().isa()) { + if (mlir::isa(initial_c.getType())) { initial_c = insertAllocOrEmitZeroConstant(shapeHelper.hc0Shape, ZTensorEncodingAttr::DataLayout::_3DS, op, rewriter, loc); } - if (input_bias.getType().isa()) { + if (mlir::isa(input_bias.getType())) { input_bias = insertAllocOrEmitZeroConstant(shapeHelper.biasShape, ZTensorEncodingAttr::DataLayout::FICO, op, rewriter, loc); } - if (hidden_bias.getType().isa()) { + if (mlir::isa(hidden_bias.getType())) { hidden_bias = insertAllocOrEmitZeroConstant(shapeHelper.biasShape, ZTensorEncodingAttr::DataLayout::FICO, op, rewriter, loc); } @@ -1196,15 +1195,15 @@ struct ZHighToZLowGRUOpLowering : public ConversionPattern { Value initial_h = operandAdaptor.getH0(); Value input_bias = operandAdaptor.getInputBias(); Value hidden_bias = operandAdaptor.getHiddenBias(); - if (initial_h.getType().isa()) { + if (mlir::isa(initial_h.getType())) { initial_h = insertAllocOrEmitZeroConstant(shapeHelper.h0Shape, ZTensorEncodingAttr::DataLayout::_3DS, op, rewriter, loc); } - if (input_bias.getType().isa()) { + if (mlir::isa(input_bias.getType())) { input_bias = insertAllocOrEmitZeroConstant(shapeHelper.biasShape, ZTensorEncodingAttr::DataLayout::ZRH, op, rewriter, loc); } - if (hidden_bias.getType().isa()) { + if (mlir::isa(hidden_bias.getType())) { hidden_bias = insertAllocOrEmitZeroConstant(shapeHelper.biasShape, ZTensorEncodingAttr::DataLayout::ZRH, op, rewriter, loc); } @@ -1250,9 +1249,8 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern { // create alloc ZHighFixGRUYOpShapeHelper shapeHelper(op, operands, &create.krnlIE); shapeHelper.computeShapeAndAssertOnFailure(); - MemRefType outputMemRefType = - typeConverter->convertType(op->getResults()[0].getType()) - .cast(); + MemRefType outputMemRefType = mlir::cast( + typeConverter->convertType(op->getResults()[0].getType())); // Value alloc = // create.mem.alignedAlloc(outputMemRefType, @@ -1377,9 +1375,8 @@ struct ZHighToZLowFixGRUYhOpLowering : public ConversionPattern { // create alloc ZHighFixGRUYhOpShapeHelper shapeHelper(op, operands, &create.krnlIE); shapeHelper.computeShapeAndAssertOnFailure(); - MemRefType outputMemRefType = - typeConverter->convertType(op->getResults()[0].getType()) - .cast(); + MemRefType outputMemRefType = mlir::cast( + typeConverter->convertType(op->getResults()[0].getType())); Value alloc = create.mem.alignedAlloc(outputMemRefType, shapeHelper.getOutputDims(0)); @@ -1451,7 +1448,7 @@ struct ZHighToZLowConv2DOpLowering : public ConversionPattern { // Prepare optional values: input_bias. Value bias = operandAdaptor.getInputBias(); - if (bias.getType().isa()) { + if (mlir::isa(bias.getType())) { // Bias's shape is [Channel_out]. SmallVector dims(1, shapeHelper.allOriginalDims[4]); bias = insertAllocOrEmitZeroConstant( @@ -1571,7 +1568,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering // // The following manual loop does a trick that puts `create.krnl.load` // inside the loop, and LLVM does not seem to read the f16 value. - uint64_t rank = res.getType().cast().getRank(); + uint64_t rank = mlir::cast(res.getType()).getRank(); ValueRange loopDef = create.krnl.defineLoops(rank); SmallVector lbs(rank, LiteralIndexExpr(0)); SmallVector ubs = shapeHelper.getOutputDims(); @@ -1633,7 +1630,7 @@ struct ZHighToZLowDataConversionLowering Type convertedType = this->typeConverter->convertType(outputTensorType); int64_t alignment = KrnlTypeConverter::getDefaultAllocAlignment(outputTensorType); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); // Types use the SIMD unrolling VL and VLHalf. @@ -1651,7 +1648,7 @@ struct ZHighToZLowDataConversionLowering // Alloc memory with padding for SIMD. Padding and loop unrolling use // unrollVL. - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); Value alloc = create.mem.alignedAllocWithSimdPadding( outputMemRefType, outputDims, unrollVL, alignment); diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp index b3145b3f29..a9cfd73a30 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp @@ -103,7 +103,7 @@ class ZLowStickLowering : public mlir::ConvertToLLVMPattern { ZLowStickOpAdaptor operandAdaptor(operands); // Do not get element type from adaptor since the type can be opaque. Type llvmElementTy = typeConverter->convertType( - stickOp.getX().getType().cast().getElementType()); + mlir::cast(stickOp.getX().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -157,10 +157,9 @@ class ZLowStickForLSTMLowering : public ConvertToLLVMPattern { ZLowStickForLSTMOp stickForLSTMOp = cast(op); ZLowStickForLSTMOpAdaptor operandAdaptor(operands); - Type llvmElementTy = typeConverter->convertType(stickForLSTMOp.getFGate() - .getType() - .cast() - .getElementType()); + Type llvmElementTy = typeConverter->convertType( + mlir::cast(stickForLSTMOp.getFGate().getType()) + .getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -245,7 +244,8 @@ class ZLowStickForGRULowering : public ConvertToLLVMPattern { ZLowStickForGRUOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - stickForGRUOp.getZGate().getType().cast().getElementType()); + mlir::cast(stickForGRUOp.getZGate().getType()) + .getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -329,7 +329,7 @@ class ZLowLSTMLowering : public ConvertToLLVMPattern { ZLowLSTMOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - lstmOp.getInput().getType().cast().getElementType()); + mlir::cast(lstmOp.getInput().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -525,7 +525,7 @@ class ZLowGRULowering : public ConvertToLLVMPattern { ZLowGRUOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - gruOp.getInput().getType().cast().getElementType()); + mlir::cast(gruOp.getInput().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -679,7 +679,7 @@ class ZLowUnstickLowering : public ConvertToLLVMPattern { ZLowUnstickOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - unstickOp.getOut().getType().cast().getElementType()); + mlir::cast(unstickOp.getOut().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -740,7 +740,7 @@ class ZLowUnaryElementwiseOpLowering : public ConvertToLLVMPattern { Value shape = operandAdaptor.getShape(); Value output = operandAdaptor.getOut(); Type llvmElementTy = typeConverter->convertType( - op->getOperand(0).getType().cast().getElementType()); + mlir::cast(op->getOperand(0).getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -818,7 +818,7 @@ class ZLowBinaryElementwiseOpLowering : public ConvertToLLVMPattern { Value shape = operandAdaptor.getShape(); Value output = operandAdaptor.getOut(); Type llvmElementTy = typeConverter->convertType( - op->getOperand(0).getType().cast().getElementType()); + mlir::cast(op->getOperand(0).getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -893,7 +893,7 @@ class ZLowSoftmaxOpLowering : public ConvertToLLVMPattern { ZLowSoftmaxOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - softmaxOp.getX().getType().cast().getElementType()); + mlir::cast(softmaxOp.getX().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -976,7 +976,7 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern { ZLowMatMulOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - matmulOp.getX().getType().cast().getElementType()); + mlir::cast(matmulOp.getX().getType()).getElementType()); bool stacked, broadcasting; if (matmulOp.getIsStacked() == -1) @@ -1114,7 +1114,7 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern { MultiDialectBuilder create(rewriter, loc); Type llvmElementTy = typeConverter->convertType( - convOp.getInput().getType().cast().getElementType()); + mlir::cast(convOp.getInput().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -1145,10 +1145,10 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern { convOp.getKernelShape().getValue(); // kernel height Value KH = create.llvm.constant(llvmI64Ty, - (int64_t)kernelShapeArrayAttr[0].cast().getInt()); + (int64_t)mlir::cast(kernelShapeArrayAttr[0]).getInt()); // kernel width Value KW = create.llvm.constant(llvmI64Ty, - (int64_t)kernelShapeArrayAttr[1].cast().getInt()); + (int64_t)mlir::cast(kernelShapeArrayAttr[1]).getInt()); // Get zDNN data type. zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); @@ -1187,10 +1187,10 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern { // Strides ArrayRef strideArrayAttr = convOp.getStrides().getValue(); - Value strideHeight = create.llvm.constant( - llvmI64Ty, (int64_t)strideArrayAttr[0].cast().getInt()); - Value strideWidth = create.llvm.constant( - llvmI64Ty, (int64_t)strideArrayAttr[1].cast().getInt()); + Value strideHeight = create.llvm.constant(llvmI64Ty, + (int64_t)mlir::cast(strideArrayAttr[0]).getInt()); + Value strideWidth = create.llvm.constant(llvmI64Ty, + (int64_t)mlir::cast(strideArrayAttr[1]).getInt()); // Activation function. Value actFunc; @@ -1264,7 +1264,7 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern { Value shape = operandAdaptor.getShape(); Value output = operandAdaptor.getOutput(); Type llvmElementTy = typeConverter->convertType( - op->getOperand(0).getType().cast().getElementType()); + mlir::cast(op->getOperand(0).getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -1293,10 +1293,10 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern { poolOp.getKernelShape().getValue(); // kernel height Value KH = create.llvm.constant(llvmI64Ty, - (int64_t)kernelShapeArrayAttr[0].cast().getInt()); + (int64_t)mlir::cast(kernelShapeArrayAttr[0]).getInt()); // kernel width Value KW = create.llvm.constant(llvmI64Ty, - (int64_t)kernelShapeArrayAttr[1].cast().getInt()); + (int64_t)mlir::cast(kernelShapeArrayAttr[1]).getInt()); // Get zDNN data type. zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); @@ -1321,10 +1321,10 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern { // Strides ArrayRef strideArrayAttr = poolOp.getStrides().getValue(); - Value strideHeight = create.llvm.constant( - llvmI64Ty, (int64_t)strideArrayAttr[0].cast().getInt()); - Value strideWidth = create.llvm.constant( - llvmI64Ty, (int64_t)strideArrayAttr[1].cast().getInt()); + Value strideHeight = create.llvm.constant(llvmI64Ty, + (int64_t)mlir::cast(strideArrayAttr[0]).getInt()); + Value strideWidth = create.llvm.constant(llvmI64Ty, + (int64_t)mlir::cast(strideArrayAttr[1]).getInt()); // Create zTensor for output. stickI8Ptr = zTensorHelper.getAlignedI8Ptr(output); @@ -1365,7 +1365,7 @@ class ZLowMeanReduce2DLowering : public ConvertToLLVMPattern { ZLowMeanReduce2DOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - meanOp.getInput().getType().cast().getElementType()); + mlir::cast(meanOp.getInput().getType()).getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -1433,7 +1433,8 @@ class ZLowBatchNormLowering : public ConvertToLLVMPattern { ZLowBatchNormOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( - batchnormOp.getInput().getType().cast().getElementType()); + mlir::cast(batchnormOp.getInput().getType()) + .getElementType()); ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp index b6ec195840..8e4cff4574 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp @@ -285,7 +285,7 @@ Value callApi(PatternRewriter &rewriter, Location loc, ModuleOp module, apiSpec.outputTy, apiSpec.inputTys, apiSpec.isVarArg); SmallVector outputTys; Type outputTy = apiSpec.outputTy; - if (!outputTy.isa()) + if (!mlir::isa(outputTy)) outputTys.emplace_back(outputTy); return create.llvm.call( ArrayRef(outputTys), symbolRef, ArrayRef(params)); @@ -305,7 +305,8 @@ size_t getRankFromMemRefType(LLVM::LLVMStructType memRefTy) { if (numElems == 3) return 0; // MemRef refers to a scalar. else - return memRefTy.getBody()[3].cast().getNumElements(); + return mlir::cast(memRefTy.getBody()[3]) + .getNumElements(); } /// Get a vector of 'size' dimensions from a 1D DenseElementsAttr. @@ -317,7 +318,7 @@ std::vector getDimsFromDenseElementsAttr(PatternRewriter &rewriter, std::vector dims; auto valueIt = valueAttr.getValues().begin(); for (unsigned int i = 0; i < size; ++i) { - int64_t dim = (*valueIt++).cast().getInt(); + int64_t dim = mlir::cast(*valueIt++).getInt(); Value dimVal = create.llvm.constant(rewriter.getI64Type(), dim); dims.emplace_back(dimVal); } @@ -367,7 +368,7 @@ std::vector getDimsFromShapeMemRefBySize(PatternRewriter &rewriter, module, addressOfOp.getGlobalNameAttr())); if (globalOp) { DenseElementsAttr valueAttr = - globalOp.getValue().value().dyn_cast(); + mlir::dyn_cast(globalOp.getValue().value()); if (valueAttr) return getDimsFromDenseElementsAttr( rewriter, loc, module, valueAttr, size); @@ -418,7 +419,7 @@ void getDimsFromMemRef(PatternRewriter &rewriter, Location loc, ModuleOp module, Value memRef, SmallVectorImpl &dims) { MemRefDescriptor memRefDesc(memRef); size_t rank = - getRankFromMemRefType(memRef.getType().cast()); + getRankFromMemRefType(mlir::cast(memRef.getType())); for (size_t i = 0; i < rank; i++) { Value dimI64 = memRefDesc.size(rewriter, loc, i); dims.emplace_back(dimI64); @@ -428,9 +429,9 @@ void getDimsFromMemRef(PatternRewriter &rewriter, Location loc, ModuleOp module, /// Type conversion from LLVMType to zDNNType. /// TODO: fill in the complete list of the zDNN types. zdnn_data_types llvmTypeToZDNNType(Type elemType) { - if (elemType.isa()) + if (mlir::isa(elemType)) return FP16; - else if (elemType.isa()) + else if (mlir::isa(elemType)) return FP32; else llvm_unreachable("Unexpected LLVM type, cannot be converted to zDNN type."); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td index 60fdbd1972..1cc7d84e47 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td @@ -28,12 +28,12 @@ def ZHigh_Dialect : Dialect { let name = "zhigh"; let summary = "A high-level dialect for the ONNX NNPA accelerator ISA."; let cppNamespace = "::onnx_mlir::zhigh"; - let useDefaultAttributePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; let usePropertiesForAttributes = 0; } //===----------------------------------------------------------------------===// -// ZHigh Attribute +// ZHigh Attribute //===----------------------------------------------------------------------===// // All of the Tensor attributes will extend this class. @@ -83,12 +83,12 @@ def ZTensorEncodingAttr : ZHigh_Attr<"ZTensorEncoding"> { // Whether a ztensor type has the specified layout. class DataLayoutOfPred : And<[ - CPred<"($_self.cast<::mlir::RankedTensorType>()) &&" - "($_self.cast<::mlir::RankedTensorType>()." - "getEncoding().dyn_cast_or_null()) &&" - "($_self.cast<::mlir::RankedTensorType>()." - "getEncoding().cast().getDataLayout()" - " == ZTensorEncodingAttr::DataLayout::" # layout # ")"> + CPred<"(mlir::cast<::mlir::RankedTensorType>($_self)) &&" + "(mlir::cast<::mlir::RankedTensorType>($_self)." + "mlir::dyn_cast_or_null(getEncoding())) &&" + "(mlir::cast<::mlir::RankedTensorType>($_self)." + "mlir::cast(getEncoding()).getDataLayout()" + " == ZTensorEncodingAttr::DataLayout::" # layout # ")"> ]>; // So far ZTensor supports only F16 for stickified data. @@ -157,7 +157,7 @@ def ZHighStickOp:ZHigh_Op<"Stick", [Pure, ]; let hasCanonicalizer = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighStickOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighStickOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighStickOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -182,7 +182,7 @@ def ZHighUnstickOp:ZHigh_Op<"Unstick", [Pure, ]; let hasCanonicalizer = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighUnstickOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighUnstickOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnstickOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -205,7 +205,7 @@ def ZHighF32ToDLF16Op:ZHigh_Op<"F32ToDLF16", [Pure, ]; let hasCanonicalizer = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighF32ToDLF16Op::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighF32ToDLF16Op::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -228,7 +228,7 @@ def ZHighDLF16ToF32Op:ZHigh_Op<"DLF16ToF32", [Pure, ]; let hasCanonicalizer = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighDLF16ToF32Op::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighDLF16ToF32Op::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -249,7 +249,7 @@ def ZHighAddOp:ZHigh_Op<"Add", [Pure, SameOperandsAndResultLayout, AnyTypeOf<[AnyZTensor]>:$Y); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighAddOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighAddOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighBinaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -270,7 +270,7 @@ def ZHighSubOp:ZHigh_Op<"Sub", [Pure, SameOperandsAndResultLayout, AnyTypeOf<[AnyZTensor]>:$Y); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighSubOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighSubOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighBinaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -291,7 +291,7 @@ def ZHighMulOp:ZHigh_Op<"Mul", [Pure, SameOperandsAndResultLayout, AnyTypeOf<[AnyZTensor]>:$Y); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighMulOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighMulOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighBinaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -312,7 +312,7 @@ def ZHighDivOp:ZHigh_Op<"Div", [Pure, SameOperandsAndResultLayout, AnyTypeOf<[AnyZTensor]>:$Y); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighDivOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighDivOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighBinaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -333,7 +333,7 @@ def ZHighMinOp:ZHigh_Op<"Min", [Pure, SameOperandsAndResultLayout, AnyTypeOf<[AnyZTensor]>:$Y); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighMinOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighMinOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighBinaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -354,7 +354,7 @@ def ZHighMaxOp:ZHigh_Op<"Max", [Pure, SameOperandsAndResultLayout, AnyTypeOf<[AnyZTensor]>:$Y); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighMaxOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighMaxOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighBinaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -373,7 +373,7 @@ def ZHighLogOp:ZHigh_Op<"Log", [Pure, SameOperandsAndResultLayout, let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighLogOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighLogOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -392,7 +392,7 @@ def ZHighExpOp:ZHigh_Op<"Exp", [Pure, SameOperandsAndResultLayout, let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighExpOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighExpOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -411,7 +411,7 @@ def ZHighReluOp:ZHigh_Op<"Relu", [Pure, SameOperandsAndResultLayout, let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighReluOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighReluOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -430,7 +430,7 @@ def ZHighTanhOp:ZHigh_Op<"Tanh", [Pure, SameOperandsAndResultLayout, let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighTanhOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighTanhOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -449,7 +449,7 @@ def ZHighSigmoidOp:ZHigh_Op<"Sigmoid", [Pure, SameOperandsAndResultLayout, let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighSigmoidOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighSigmoidOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -470,7 +470,7 @@ def ZHighSoftmaxOp:ZHigh_Op<"Softmax", [Pure, SameOperandsAndResultLayout, DefaultValuedStrAttr:$act_func); let results = (outs ZTensor_3DS:$Out); let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighSoftmaxOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighSoftmaxOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -484,8 +484,8 @@ def ZHighMeanReduce2DOp:ZHigh_Op<"MeanReduce2d", [Pure, DeclareOpInterfaceMethods]> { let summary = "ZHigh 2D mean reduce operation"; let description = [{ - ZHigh operation to perform 2D mean reduce. Given an input 4D tensor, - returns a downsampled tensor reducing the middle 2nd and 3rd dimensions + ZHigh operation to perform 2D mean reduce. Given an input 4D tensor, + returns a downsampled tensor reducing the middle 2nd and 3rd dimensions to a size of 1 based on the mean of the original values. Input and Output tensors should be in the 3D layout. }]; @@ -493,13 +493,13 @@ def ZHighMeanReduce2DOp:ZHigh_Op<"MeanReduce2d", [Pure, let results = (outs ZTensor_NHWC:$output); let builders = [ OpBuilder<(ins "::mlir::Value":$input), [{ - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, input); }]> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighMeanReduce2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighMeanReduce2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighMeanReduce2DOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -524,13 +524,13 @@ def ZHighMaxPool2DOp:ZHigh_Op<"MaxPool2D", [Pure, let builders = [ OpBuilder<(ins "::mlir::Value":$input, "::mlir::ArrayAttr":$kernel_shape, "::mlir::ArrayAttr":$strides, "::mlir::StringAttr":$padding_type), [{ - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, input, kernel_shape, strides, padding_type); }]> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighMaxPool2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighMaxPool2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighPoolingOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -555,13 +555,13 @@ def ZHighAvgPool2DOp:ZHigh_Op<"AvgPool2D", [Pure, let builders = [ OpBuilder<(ins "::mlir::Value":$input, "::mlir::ArrayAttr":$kernel_shape, "::mlir::ArrayAttr":$strides, "::mlir::StringAttr":$padding_type), [{ - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, input, kernel_shape, strides, padding_type); }]> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighAvgPool2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighAvgPool2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighPoolingOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -584,14 +584,14 @@ def ZHighMatMulOp:ZHigh_Op<"MatMul", [Pure, let results = (outs AnyTypeOf<[ZTensor_2D, ZTensor_3DS]>:$Out); let builders = [ OpBuilder<(ins "::mlir::Value":$X, "::mlir::Value":$Y, "::mlir::Value":$B), [{ - Type elementType = X.getType().cast().getElementType(); + Type elementType = mlir::cast(X.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, X, Y, B); }]> ]; let hasVerifier = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighMatMulOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighMatMulOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighMatMulOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -610,11 +610,11 @@ def ZHighLSTMOp:ZHigh_Op<"LSTM", [Pure, * Shape for input_weights is `[D, I, 4*H]`. * Shape for hidden_weights is `[D, H, 4*H]`. * Shape for input_bias and hidden_bias is `[D, 4*H]`. - * Shape for hn_output is `[S, D, B, H]` if return all timesteps + * Shape for hn_output is `[S, D, B, H]` if return all timesteps and `[1, D, B, H]` if return the final step only. * Shape for cf_output is `[1, D, B, H]`. - * S is timesteps, D is the number of directions (1 for unidirectional and - * 2 for bidirectional), B is batch size, I is input size, and + * S is timesteps, D is the number of directions (1 for unidirectional and + * 2 for bidirectional), B is batch size, I is input size, and * H is hidden size. * direction accepts "forward", "reverse", or "bidirectional * return_all_steps: -1 returns all timesteps, 0: returns only the last timestep. @@ -637,7 +637,7 @@ def ZHighLSTMOp:ZHigh_Op<"LSTM", [Pure, "::mlir::Value":$hidden_weights, "::mlir::Value":$hidden_bias, "::mlir::IntegerAttr":$hidden_size, "::mlir::StringAttr":$direction, "::mlir::IntegerAttr":$return_all_steps), [{ - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, resType, input, h0, c0, input_weights, input_bias, hidden_weights, @@ -646,7 +646,7 @@ def ZHighLSTMOp:ZHigh_Op<"LSTM", [Pure, ]; let hasVerifier = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighLSTMOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighLSTMOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighLSTMOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -665,10 +665,10 @@ def ZHighGRUOp:ZHigh_Op<"GRU", [Pure, * Shape for input_weights is `[D, I, 3*H]`. * Shape for hidden_weights is `[D, H, 3*H]`. * Shape for input_bias and hidden_bias is `[D, 3*H]`. - * Shape for hn_output is `[S, D, B, H]` if return all timesteps + * Shape for hn_output is `[S, D, B, H]` if return all timesteps and `[1, D, B, H]` if return the final step only. - * S is timesteps, D is the number of directions (1 for unidirectional and - * 2 for bidirectional), B is batch size, I is input size, and + * S is timesteps, D is the number of directions (1 for unidirectional and + * 2 for bidirectional), B is batch size, I is input size, and * H is hidden size. * direction accepts "forward", "reverse", or "bidirectional * return_all_steps: -1 returns all timesteps, 0: returns only the last timestep." @@ -688,7 +688,7 @@ def ZHighGRUOp:ZHigh_Op<"GRU", [Pure, "::mlir::Value":$input_bias, "::mlir::Value":$hidden_weights, "::mlir::Value":$hidden_bias, "::mlir::IntegerAttr":$hidden_size, "::mlir::StringAttr":$direction, "::mlir::IntegerAttr":$return_all_steps), [{ - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, input, h0, input_weights, input_bias, hidden_weights, @@ -697,7 +697,7 @@ def ZHighGRUOp:ZHigh_Op<"GRU", [Pure, ]; let hasVerifier = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighGRUOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighGRUOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighGRUOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -712,9 +712,9 @@ def ZHighStickForLSTMOp:ZHigh_Op<"StickForLSTM", [Pure, let summary = "ZHigh stick operation for LSTM"; let description = [{ ZHigh operation to perform a stick for LSTM. - Variadic: list of pointers for input data to be transformed: - - LSTM concatenated: 4 data pointers, one for each input gate in - Forget, Input, Cell, Output (FICO) order, + Variadic: list of pointers for input data to be transformed: + - LSTM concatenated: 4 data pointers, one for each input gate in + Forget, Input, Cell, Output (FICO) order, }]; let arguments = (ins TensorOf<[F32]>:$f_gate, TensorOf<[F32]>:$i_gate, @@ -730,7 +730,7 @@ def ZHighStickForLSTMOp:ZHigh_Op<"StickForLSTM", [Pure, }]> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighStickForLSTMOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighStickForLSTMOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighStickForLSTMOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -745,7 +745,7 @@ def ZHighStickForGRUOp:ZHigh_Op<"StickForGRU", [Pure, let summary = "ZHigh stick operation for GRU"; let description = [{ ZHigh operation to perform a stick for GRU. - Variadic: list of pointers for input data to be transformed: + Variadic: list of pointers for input data to be transformed: - GRU concatenated: 3 data pointers, one for each input gate in (Z)update, Reset, Hidden, (ZRH) gate order }]; @@ -761,7 +761,7 @@ def ZHighStickForGRUOp:ZHigh_Op<"StickForGRU", [Pure, }]> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighStickForGRUOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighStickForGRUOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighStickForGRUOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -775,14 +775,14 @@ def ZHighConv2DOp:ZHigh_Op<"Conv2D", [Pure, DeclareOpInterfaceMethods]> { let summary = "ZHigh 2D convolution operation"; let description = [{ - ZHigh operation to perform 2D convolution. + ZHigh operation to perform 2D convolution. * input: `[num_batches, height_in, width_in, channels_in]` - * input_kernel: `[kernel_height, kernel_width, channels_in, channels_out]` + * input_kernel: `[kernel_height, kernel_width, channels_in, channels_out]` * input_bias: `[channels_out] ` - * kernel_shape: 1D array of kernel height and width - * strides: 1D array of stride height and width - * padding_type: SAME_PADDING or VALID_PADDING - * act_func: ACT_NONE or ACT_RELU + * kernel_shape: 1D array of kernel height and width + * strides: 1D array of stride height and width + * padding_type: SAME_PADDING or VALID_PADDING + * act_func: ACT_NONE or ACT_RELU * output: `[num_batches, height_out, width_out, channels_out]` }]; let arguments = (ins ZTensor_NHWC:$input, @@ -798,7 +798,7 @@ def ZHighConv2DOp:ZHigh_Op<"Conv2D", [Pure, OpBuilder<(ins "::mlir::Value":$input, "::mlir::Value":$input_kernel, "::mlir::Value":$input_bias, "::mlir::ArrayAttr":$kernel_shape, "::mlir::ArrayAttr":$strides, "::mlir::StringAttr":$padding_type, "::mlir::StringAttr":$act_func), [{ - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); build($_builder, $_state, resType, input, input_kernel, input_bias, kernel_shape, strides, padding_type, act_func); @@ -806,7 +806,7 @@ def ZHighConv2DOp:ZHigh_Op<"Conv2D", [Pure, ]; let hasVerifier = 1; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighConv2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighConv2DOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighConv2DOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -833,7 +833,7 @@ def ZHighBatchNormOp:ZHigh_Op<"BatchNorm", [Pure, }]> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighBatchNormOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighBatchNormOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); @@ -861,7 +861,7 @@ def ZHighStickifiedConstantOfShapeOp:ZHigh_Op<"StickifiedConstantOfShape", [Pure let summary = "ZHigh Stickified Constant operation for a dynamic shape"; let description = [{ This operator produces a constant tensor to store stickified data. - The stickified data is defined by a f32 scalar value, a dynamic shape + The stickified data is defined by a f32 scalar value, a dynamic shape and a layout. Stickified data is 4K-aligned. }]; let arguments = (ins TensorOf<[I64]>:$shape, @@ -873,7 +873,7 @@ def ZHighStickifiedConstantOfShapeOp:ZHigh_Op<"StickifiedConstantOfShape", [Pure "::mlir::StringAttr":$layout)> ]; let extraClassDefinition = [{ - onnx_mlir::ONNXOpShapeHelper * ZHighStickifiedConstantOfShapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::ONNXOpShapeHelper * ZHighStickifiedConstantOfShapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { onnx_mlir::ONNXOpShapeHelper *sh = new ZHighStickifiedConstantOfShapeOpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp index 2a4271779c..8e70fe364f 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp @@ -73,20 +73,20 @@ namespace zhigh { std::vector getZHighAuxSplitResultType( Value input, int64_t axis, ArrayAttr split) { - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); std::vector outputTypes; if (split.size() == 0) { llvm_unreachable("Unsupported split (size==0)"); } else { ArrayRef inputShape = - input.getType().cast().getShape(); + mlir::cast(input.getType()).getShape(); int64_t splitNum = split.size(); for (int i = 0; i < splitNum; i++) { SmallVector outputShape; for (unsigned int dim = 0; dim < inputShape.size(); dim++) { - outputShape.emplace_back((dim == axis) - ? split[dim].cast().getInt() - : inputShape[dim]); + outputShape.emplace_back( + (dim == axis) ? mlir::cast(split[dim]).getInt() + : inputShape[dim]); } outputTypes.emplace_back(RankedTensorType::get(outputShape, elementType)); } @@ -114,7 +114,7 @@ Attribute ZTensorEncodingAttr::parse(AsmParser &parser, Type type) { // Process the data from the parsed dictionary value into struct-like data. for (const NamedAttribute &attr : dict) { if (attr.getName() == "dataLayout") { - StringAttr layoutAttr = attr.getValue().dyn_cast(); + StringAttr layoutAttr = mlir::dyn_cast(attr.getValue()); if (!layoutAttr) { parser.emitError( parser.getNameLoc(), "expected a string value for data layout"); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp index 8f25d842eb..a05e73890e 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp @@ -107,9 +107,12 @@ LogicalResult ZHighConv2DOp::verify() { return failure(); // Verify bias shape. - if (!B.getType().isa() && hasRankedType(B) && hasRankedType(K)) { - int64_t channelOutB = B.getType().cast().getShape()[0]; - int64_t channelOutK = K.getType().cast().getShape()[3]; + if (!mlir::isa(B.getType()) && hasRankedType(B) && + hasRankedType(K)) { + int64_t channelOutB = + mlir::cast(B.getType()).getShape()[0]; + int64_t channelOutK = + mlir::cast(K.getType()).getShape()[3]; if (!ShapedType::isDynamic(channelOutB) && !ShapedType::isDynamic(channelOutK) && (channelOutB != channelOutK)) return failure(); @@ -117,11 +120,11 @@ LogicalResult ZHighConv2DOp::verify() { // Verify kernel shape. ArrayAttr kernelShape = getKernelShape(); - int64_t attrKH = kernelShape[0].cast().getInt(); - int64_t attrKW = kernelShape[1].cast().getInt(); + int64_t attrKH = mlir::cast(kernelShape[0]).getInt(); + int64_t attrKW = mlir::cast(kernelShape[1]).getInt(); if (hasRankedType(K)) { - int64_t KH = K.getType().cast().getShape()[0]; - int64_t KW = K.getType().cast().getShape()[1]; + int64_t KH = mlir::cast(K.getType()).getShape()[0]; + int64_t KW = mlir::cast(K.getType()).getShape()[1]; if (!ShapedType::isDynamic(KH) && KH != attrKH) return failure(); if (!ShapedType::isDynamic(KW) && KW != attrKW) @@ -140,7 +143,8 @@ LogicalResult ZHighConv2DOp::inferShapes( if (!hasRankedType(getInput()) || !hasRankedType(getInputKernel())) return success(); - RankedTensorType inputType = getInput().getType().cast(); + RankedTensorType inputType = + mlir::cast(getInput().getType()); ZHighConv2DOpShapeHelper shapeHelper(getOperation()); return shapeHelper.computeShapeAndUpdateType( inputType.getElementType(), inputType.getEncoding()); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/ZHighDLF16ToF32.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/ZHighDLF16ToF32.td index 6fd89e55f4..5e24b965e4 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/ZHighDLF16ToF32.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/ZHighDLF16ToF32.td @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#ifndef DLF16_TO_F32_TD -#define DLF16_TO_F32_TD +#ifndef DLF16_TO_F32_TD +#define DLF16_TO_F32_TD #ifndef OP_BASE include "src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td" @@ -31,11 +31,11 @@ include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td" /// >; def GetTypeInDLF16: NativeCodeCall< - "RankedTensorType::get($0.getType().cast().getShape(), $_builder.getF16Type())" + "RankedTensorType::get(mlir::cast($0.getType()).getShape(), $_builder.getF16Type())" >; //===----------------------------------------------------------------------===// -// DRR patterns +// DRR patterns //===----------------------------------------------------------------------===// // zhigh.DLF16ToF32 (zhigh.F32ToDLF16(%X)) = %X diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp index d3939e01a7..6c2cebe5c9 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp @@ -100,15 +100,15 @@ LogicalResult ZHighGRUOp::verify() { // Verify hidden size in W. if (hasRankedType(W)) { - int64_t dim2 = W.getType().cast().getShape()[2]; + int64_t dim2 = mlir::cast(W.getType()).getShape()[2]; if (!ShapedType::isDynamic(dim2) && (dim2 != hiddenSize * 3)) return failure(); } // Verify hidden size in R. if (hasRankedType(R)) { - int64_t dim1 = R.getType().cast().getShape()[1]; - int64_t dim2 = R.getType().cast().getShape()[2]; + int64_t dim1 = mlir::cast(R.getType()).getShape()[1]; + int64_t dim2 = mlir::cast(R.getType()).getShape()[2]; if (!ShapedType::isDynamic(dim1) && (dim1 != hiddenSize)) return failure(); if (!ShapedType::isDynamic(dim2) && (dim2 != hiddenSize * 3)) @@ -116,15 +116,15 @@ LogicalResult ZHighGRUOp::verify() { } // Verify hidden size in WB. - if (!WB.getType().isa() && hasRankedType(WB)) { - int64_t dim1 = WB.getType().cast().getShape()[1]; + if (!mlir::isa(WB.getType()) && hasRankedType(WB)) { + int64_t dim1 = mlir::cast(WB.getType()).getShape()[1]; if (!ShapedType::isDynamic(dim1) && (dim1 != hiddenSize * 3)) return failure(); } // Verify hidden size in RB. - if (!RB.getType().isa() && hasRankedType(RB)) { - int64_t dim1 = RB.getType().cast().getShape()[1]; + if (!mlir::isa(RB.getType()) && hasRankedType(RB)) { + int64_t dim1 = mlir::cast(RB.getType()).getShape()[1]; if (!ShapedType::isDynamic(dim1) && (dim1 != hiddenSize * 3)) return failure(); } @@ -141,7 +141,8 @@ LogicalResult ZHighGRUOp::inferShapes( if (!hasRankedType(getInput()) || !hasRankedType(getHiddenWeights())) return success(); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); ZTensorEncodingAttr encoding = ZTensorEncodingAttr::get( this->getContext(), ZTensorEncodingAttr::DataLayout::_4DS); ZHighGRUOpShapeHelper shapeHelper(getOperation()); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp index a0effef16d..41648352e0 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp @@ -102,15 +102,15 @@ LogicalResult ZHighLSTMOp::verify() { // Verify hidden size in W. if (hasRankedType(W)) { - int64_t dim2 = W.getType().cast().getShape()[2]; + int64_t dim2 = mlir::cast(W.getType()).getShape()[2]; if (!ShapedType::isDynamic(dim2) && (dim2 != hiddenSize * 4)) return failure(); } // Verify hidden size in R. if (hasRankedType(R)) { - int64_t dim1 = R.getType().cast().getShape()[1]; - int64_t dim2 = R.getType().cast().getShape()[2]; + int64_t dim1 = mlir::cast(R.getType()).getShape()[1]; + int64_t dim2 = mlir::cast(R.getType()).getShape()[2]; if (!ShapedType::isDynamic(dim1) && (dim1 != hiddenSize)) return failure(); if (!ShapedType::isDynamic(dim2) && (dim2 != hiddenSize * 4)) @@ -118,15 +118,15 @@ LogicalResult ZHighLSTMOp::verify() { } // Verify hidden size in WB. - if (!WB.getType().isa() && hasRankedType(WB)) { - int64_t dim1 = WB.getType().cast().getShape()[1]; + if (!mlir::isa(WB.getType()) && hasRankedType(WB)) { + int64_t dim1 = mlir::cast(WB.getType()).getShape()[1]; if (!ShapedType::isDynamic(dim1) && (dim1 != hiddenSize * 4)) return failure(); } // Verify hidden size in RB. - if (!RB.getType().isa() && hasRankedType(RB)) { - int64_t dim1 = RB.getType().cast().getShape()[1]; + if (!mlir::isa(RB.getType()) && hasRankedType(RB)) { + int64_t dim1 = mlir::cast(RB.getType()).getShape()[1]; if (!ShapedType::isDynamic(dim1) && (dim1 != hiddenSize * 4)) return failure(); } @@ -150,7 +150,8 @@ LogicalResult ZHighLSTMOp::inferShapes( SmallVector hnOutputDims, cfOutputDims; IndexExpr::getShape(shapeHelper.getOutputDims(0), hnOutputDims); IndexExpr::getShape(shapeHelper.getOutputDims(1), cfOutputDims); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ZTensorEncodingAttr encoding = ZTensorEncodingAttr::get( this->getContext(), ZTensorEncodingAttr::DataLayout::_4DS); updateType( diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp index 20b95b424b..6d556ea36f 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp @@ -103,7 +103,8 @@ LogicalResult ZHighMatMulOp::inferShapes( SmallVector outputDims; IndexExpr::getShape(shapeHelper.getOutputDims(), outputDims); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); ZTensorEncodingAttr encoding; if (outputDims.size() == 2) encoding = ZTensorEncodingAttr::get( @@ -128,7 +129,7 @@ LogicalResult ZHighMatMulOp::verify() { ZTensorEncodingAttr::DataLayout yLayout = getZTensorLayout(Y.getType()); // Bias can be None. ZTensorEncodingAttr::DataLayout bLayout; - bool hasBias = !B.getType().isa(); + bool hasBias = !mlir::isa(B.getType()); if (hasBias) bLayout = getZTensorLayout(B.getType()); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp index 0e335269c5..5153fbd6bf 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp @@ -56,7 +56,7 @@ LogicalResult ZHighMeanReduce2DOp::inferShapes( if (!hasRankedType(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); ZHighMeanReduce2DOpShapeHelper shapeHelper(getOperation()); return shapeHelper.computeShapeAndUpdateType( inputType.getElementType(), inputType.getEncoding()); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp index a094193793..f56e4e2684 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp @@ -26,7 +26,7 @@ namespace zhigh { /// Check if a value type is ranked or unranked. bool hasRankedType(Value val) { - ShapedType shapedType = val.getType().cast(); + ShapedType shapedType = mlir::cast(val.getType()); return (shapedType && shapedType.hasRank()); } @@ -143,15 +143,15 @@ StringAttr convertZTensorDataLayoutToStringAttr( // Utility functions to query ztensor information. bool isZTensor(Type type) { - if (auto ttp = type.dyn_cast()) - if (ttp.getEncoding().dyn_cast_or_null()) + if (auto ttp = mlir::dyn_cast(type)) + if (mlir::dyn_cast_or_null(ttp.getEncoding())) return true; return false; } ZTensorEncodingAttr getZTensorEncoding(Type type) { - if (auto ttp = type.dyn_cast()) - return ttp.getEncoding().dyn_cast_or_null(); + if (auto ttp = mlir::dyn_cast(type)) + return mlir::dyn_cast_or_null(ttp.getEncoding()); return nullptr; } @@ -173,24 +173,24 @@ StringAttr getZTensorLayoutAttr(OpBuilder &builder, Type type) { Value getMinusBcastConst( mlir::OpBuilder &builder, Location loc, FloatAttr floatAttr, Value X) { - ShapedType xType = X.getType().cast(); + ShapedType xType = mlir::cast(X.getType()); assert(xType.hasStaticShape() && "expected static shape"); float val = floatAttr.getValueAsDouble() * -1.0; DenseElementsAttr denseAttr = - DenseElementsAttr::get(X.getType().cast(), val); + DenseElementsAttr::get(mlir::cast(X.getType()), val); MultiDialectBuilder create(builder, loc); return create.onnx.constant(denseAttr); } Value getConstantOfType( OpBuilder &builder, Location loc, Type type, float val) { - ShapedType shapedType = type.cast(); + ShapedType shapedType = mlir::cast(type); assert(shapedType.hasStaticShape() && "expected static shape"); Type elementType = shapedType.getElementType(); DenseElementsAttr denseAttr; - if (elementType.isa()) + if (mlir::isa(elementType)) denseAttr = DenseElementsAttr::get(shapedType, (int64_t)val); - else if (elementType.isa()) + else if (mlir::isa(elementType)) denseAttr = DenseElementsAttr::get(shapedType, val); else llvm_unreachable("Unsupport type"); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td index 6f579ad925..8a005652fa 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#ifndef OP_HELPER -#define OP_HELPER +#ifndef OP_HELPER +#define OP_HELPER #ifndef OP_BASE include "src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td" @@ -37,7 +37,7 @@ def NotSameLayout: Constraint< "Two ztensors have different layouts" >; -def IsNoneType : Constraint())">>; +def IsNoneType : Constraint(($_self).getType())">>; def GetLayout : NativeCodeCall< "::onnx_mlir::zhigh::convertZTensorDataLayoutToStringAttr($_builder, " @@ -77,7 +77,7 @@ class GetConstantOfType : NativeCodeCall< def IsStaticShapeTensor: Constraint< CPred< - "$0.getType().cast<::mlir::ShapedType>().hasStaticShape()">, + "mlir::cast<::mlir::ShapedType>($0.getType()).hasStaticShape()">, "is a tensor of static shape">; def IsPlusConstantFloat : Constraint< @@ -195,4 +195,4 @@ def GetAxisNHWC : NativeCodeCall< "::onnx_mlir::zhigh::getAxisNHWC($0)" >; -#endif // OP_HELPER +#endif // OP_HELPER diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp index fc46e5c58e..5f9a11a5ce 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp @@ -98,7 +98,8 @@ LogicalResult ZHighMaxPool2DOp::inferShapes( if (!hasRankedType(getInput())) return success(); - RankedTensorType inputType = getInput().getType().cast(); + RankedTensorType inputType = + mlir::cast(getInput().getType()); ZHighPoolingOpShapeHelper shapeHelper(getOperation()); return shapeHelper.computeShapeAndUpdateType( inputType.getElementType(), inputType.getEncoding()); @@ -113,7 +114,8 @@ LogicalResult ZHighAvgPool2DOp::inferShapes( if (!hasRankedType(getInput())) return success(); - RankedTensorType inputType = getInput().getType().cast(); + RankedTensorType inputType = + mlir::cast(getInput().getType()); ZHighPoolingOpShapeHelper shapeHelper(getOperation()); return shapeHelper.computeShapeAndUpdateType( inputType.getElementType(), inputType.getEncoding()); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp index ce98dc640b..830fdfbf3a 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp @@ -32,8 +32,8 @@ void ZHighStickOp::build( OpBuilder &builder, OperationState &state, Value input, StringAttr layout) { Type resType = builder.getNoneType(); Type resElementType = builder.getF16Type(); - if (!input.getType().isa()) { - ShapedType inputType = input.getType().cast(); + if (!mlir::isa(input.getType())) { + ShapedType inputType = mlir::cast(input.getType()); int64_t rank = -1; if (inputType.hasRank()) { rank = inputType.getRank(); @@ -111,7 +111,7 @@ LogicalResult ZHighStickOp::inferShapes( if (isa(input.getType()) || !hasRankedType(input)) return success(); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); StringAttr layout = getLayoutAttr(); int64_t rank = inputType.getRank(); @@ -123,7 +123,8 @@ LogicalResult ZHighStickOp::inferShapes( auto encoding = ZTensorEncodingAttr::get(this->getContext(), dataLayout); ZHighStickOpShapeHelper shapeHelper(getOperation()); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); return shapeHelper.computeShapeAndUpdateType(elementType, encoding); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp index f43fc46462..911d343c02 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp @@ -55,7 +55,8 @@ LogicalResult ZHighStickForGRUOp::inferShapes( !hasRankedType(getHGate())) return success(); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); ZTensorEncodingAttr encoding = ZTensorEncodingAttr::get( this->getContext(), ZTensorEncodingAttr::DataLayout::ZRH); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp index 572e0a6a7f..8f1b4a07a1 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp @@ -55,7 +55,8 @@ LogicalResult ZHighStickForLSTMOp::inferShapes( !hasRankedType(getCGate()) && !hasRankedType(getOGate())) return success(); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); ZTensorEncodingAttr encoding = ZTensorEncodingAttr::get( this->getContext(), ZTensorEncodingAttr::DataLayout::FICO); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp index 00c7df06df..c46f97dd79 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp @@ -26,7 +26,7 @@ namespace zhigh { void ZHighStickifiedConstantOfShapeOp::build(OpBuilder &builder, OperationState &state, Value shape, FloatAttr value, StringAttr layout) { Type resType = builder.getNoneType(); - ShapedType shapeType = shape.getType().cast(); + ShapedType shapeType = mlir::cast(shape.getType()); Type elementType = builder.getF16Type(); if (shapeType.hasRank()) { @@ -61,7 +61,7 @@ LogicalResult ZHighStickifiedConstantOfShapeOpShapeHelper::computeShape() { if (!hasRankedType(shape)) return success(); - auto shapeType = shape.getType().cast(); + auto shapeType = mlir::cast(shape.getType()); int64_t rank = shapeType.getShape()[0]; // Output dims of result. @@ -97,7 +97,7 @@ LogicalResult ZHighStickifiedConstantOfShapeOp::inferShapes( if (!hasRankedType(shape)) return success(); - auto shapeType = shape.getType().cast(); + auto shapeType = mlir::cast(shape.getType()); StringAttr layout = getLayoutAttr(); int64_t rank = shapeType.getShape()[0]; @@ -109,7 +109,8 @@ LogicalResult ZHighStickifiedConstantOfShapeOp::inferShapes( auto encoding = ZTensorEncodingAttr::get(this->getContext(), dataLayout); ZHighStickifiedConstantOfShapeOpShapeHelper shapeHelper(getOperation()); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); return shapeHelper.computeShapeAndUpdateType(elementType, encoding); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp index b2c1d4d0bb..77152ba81c 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp @@ -32,7 +32,7 @@ void ZHighUnstickOp::build( OpBuilder &builder, OperationState &state, Value input) { Type resType; Type resElementType = builder.getF32Type(); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); if (hasRankedType(input)) { // Compute shape. ArrayRef inputShape = inputType.getShape(); @@ -105,7 +105,8 @@ LogicalResult ZHighUnstickOp::inferShapes( return success(); ZHighUnstickOpShapeHelper shapeHelper(getOperation()); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index 1a3fd03dfd..8af7b50978 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -136,9 +136,8 @@ void NNPAAccelerator::registerPasses(int optLevel) const { mlir::MemRefType NNPAAccelerator::convertTensorTypeToMemRefType( const mlir::TensorType tensorType) const { assert(tensorType.hasRank() && "expected only ranked shapes"); - if (tensorType.cast() - .getEncoding() - .dyn_cast_or_null()) { + if (mlir::dyn_cast_or_null( + mlir::cast(tensorType).getEncoding())) { onnx_mlir::zhigh::ZMemRefType zMemRefType = onnx_mlir::zhigh::convertZTensorToMemRefType(tensorType); return zMemRefType.value; @@ -149,9 +148,8 @@ mlir::MemRefType NNPAAccelerator::convertTensorTypeToMemRefType( int64_t NNPAAccelerator::getDefaultAllocAlignment( const mlir::TensorType tensorType) const { assert(tensorType.hasRank() && "expected only ranked shapes"); - if (tensorType.cast() - .getEncoding() - .dyn_cast_or_null()) + if (mlir::dyn_cast_or_null( + mlir::cast(tensorType).getEncoding())) return gAlignment; return -1; } diff --git a/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp b/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp index ada3f02b06..475b2f03b7 100644 --- a/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp +++ b/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp @@ -50,8 +50,8 @@ static int constantFoldStdAllocID = 0; /// Get a constant value from a ConstantOp. static int64_t getConstantValue(arith::ConstantOp constOp) { - if (IntegerAttr attr = constOp.getValue().dyn_cast()) { - int64_t val = attr.cast().getInt(); + if (IntegerAttr attr = mlir::dyn_cast(constOp.getValue())) { + int64_t val = mlir::cast(attr).getInt(); return val; } else { llvm_unreachable("Only support IntegerAttr"); @@ -69,7 +69,7 @@ class FoldStdAlloc : public OpRewritePattern { Location loc = allocOp.getLoc(); Value memRef = allocOp.getResult(); - MemRefType memRefType = memRef.getType().dyn_cast(); + MemRefType memRefType = mlir::dyn_cast(memRef.getType()); Type elementType = memRefType.getElementType(); // 1. Match diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp index 2e080c2c4c..9be86865ed 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp @@ -42,7 +42,7 @@ namespace { /// modification. bool valueFromZTensor(Value tensor) { // Function arguments are always CPU tensors. - if (tensor.dyn_cast()) + if (mlir::dyn_cast(tensor)) return false; Operation *op = tensor.getDefiningOp(); diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp index fec2ef8b6f..a32bacb4c4 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp @@ -39,7 +39,7 @@ static void getRawData(DenseElementsAttr denseAttr, std::vector &data) { if (!denseAttr.isSplat()) { data = denseAttr.getRawData(); } else { - ShapedType denseShapeType = denseAttr.getType().cast(); + ShapedType denseShapeType = mlir::cast(denseAttr.getType()); std::vector rawData = denseAttr.getRawData(); int64_t numElements = denseShapeType.getNumElements(); for (int i = 0; i < numElements; i++) @@ -49,8 +49,8 @@ static void getRawData(DenseElementsAttr denseAttr, std::vector &data) { /// MLIR type to zDNN type. zdnn_data_types mlirTypeToZDNNType(Type elementType) { - if (elementType.isa()) { - FloatType floatTy = elementType.cast(); + if (mlir::isa(elementType)) { + FloatType floatTy = mlir::cast(elementType); if (floatTy.getWidth() == 16) { return FP16; } else if (floatTy.getWidth() == 32) { @@ -91,13 +91,13 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter, Value replacingValue, Value input, StringAttr layout) { Location loc = replacingValue.getLoc(); Operation *op = input.getDefiningOp(); - ArrayRef shape = input.getType().cast().getShape(); - Type elementType = input.getType().cast().getElementType(); + ArrayRef shape = mlir::cast(input.getType()).getShape(); + Type elementType = mlir::cast(input.getType()).getElementType(); int rank = shape.size(); // Read dense attributes. - DenseElementsAttr dataAttr = op->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); + DenseElementsAttr dataAttr = mlir::dyn_cast_or_null( + op->getAttrOfType<::mlir::Attribute>("value")); assert(dataAttr && "Attribute is null"); // Read attributes's raw data. std::vector rawData; @@ -141,23 +141,20 @@ ZHighStickifiedConstantOp createConstantForStickForLSTM( Operation *cOp = inputC.getDefiningOp(); Operation *oOp = inputO.getDefiningOp(); - ArrayRef fShape = inputF.getType().cast().getShape(); + ArrayRef fShape = + mlir::cast(inputF.getType()).getShape(); assert((fShape.size() == 2 || fShape.size() == 3) && "Wrong tensor shape"); - Type elementType = inputF.getType().cast().getElementType(); + Type elementType = mlir::cast(inputF.getType()).getElementType(); // Read dense attributes. - DenseElementsAttr fDataAttr = - fOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); - DenseElementsAttr iDataAttr = - iOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); - DenseElementsAttr cDataAttr = - cOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); - DenseElementsAttr oDataAttr = - oOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); + DenseElementsAttr fDataAttr = mlir::dyn_cast_or_null( + fOp->getAttrOfType<::mlir::Attribute>("value")); + DenseElementsAttr iDataAttr = mlir::dyn_cast_or_null( + iOp->getAttrOfType<::mlir::Attribute>("value")); + DenseElementsAttr cDataAttr = mlir::dyn_cast_or_null( + cOp->getAttrOfType<::mlir::Attribute>("value")); + DenseElementsAttr oDataAttr = mlir::dyn_cast_or_null( + oOp->getAttrOfType<::mlir::Attribute>("value")); assert((fDataAttr && iDataAttr && cDataAttr && oDataAttr) && "Attribute is null"); // Read attributes's raw data. @@ -205,20 +202,18 @@ ZHighStickifiedConstantOp createConstantForStickForGRU( Operation *rOp = inputR.getDefiningOp(); Operation *hOp = inputH.getDefiningOp(); - ArrayRef zShape = inputZ.getType().cast().getShape(); + ArrayRef zShape = + mlir::cast(inputZ.getType()).getShape(); assert((zShape.size() == 2 || zShape.size() == 3) && "Wrong tensor shape"); - Type elementType = inputZ.getType().cast().getElementType(); + Type elementType = mlir::cast(inputZ.getType()).getElementType(); // Read dense attributes. - DenseElementsAttr zDataAttr = - zOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); - DenseElementsAttr rDataAttr = - rOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); - DenseElementsAttr hDataAttr = - hOp->getAttrOfType<::mlir::Attribute>("value") - .dyn_cast_or_null(); + DenseElementsAttr zDataAttr = mlir::dyn_cast_or_null( + zOp->getAttrOfType<::mlir::Attribute>("value")); + DenseElementsAttr rDataAttr = mlir::dyn_cast_or_null( + rOp->getAttrOfType<::mlir::Attribute>("value")); + DenseElementsAttr hDataAttr = mlir::dyn_cast_or_null( + hOp->getAttrOfType<::mlir::Attribute>("value")); assert((zDataAttr && rDataAttr && hDataAttr) && "Attribute is null"); // Read attributes's raw data. std::vector rawZData, rawHData, rawRData, rawOData; diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp index 107f6d4abb..b0fd597da5 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp @@ -43,7 +43,8 @@ std::pair areProducedByUnstickOpSameLayout( PatternRewriter &rewriter, ValueRange values) { // Check the first value and get its layout. Value first = values[0]; - if (first.isa() || !isa(first.getDefiningOp())) + if (mlir::isa(first) || + !isa(first.getDefiningOp())) return std::make_pair(false, nullptr); Value firstStickifiedVal = cast(first.getDefiningOp()).getIn(); @@ -53,7 +54,7 @@ std::pair areProducedByUnstickOpSameLayout( // Check all values. bool allTheSame = llvm::all_of(values, [&](Value v) { using namespace onnx_mlir::zhigh; - if (v.isa() || !isa(v.getDefiningOp())) + if (mlir::isa(v) || !isa(v.getDefiningOp())) return false; Value stickifiedVal = cast(v.getDefiningOp()).getIn(); StringAttr nextLayout = convertZTensorDataLayoutToStringAttr( @@ -121,7 +122,7 @@ class ONNXUnaryOpLayoutPropPattern : public OpRewritePattern { Value output = unaryOp.getY(); // Input is a block argument, do nothing. - if (input.dyn_cast()) + if (mlir::dyn_cast(input)) return failure(); // Input is a CPU tensor, do nothing. @@ -176,7 +177,7 @@ class ONNXBinaryOpLayoutPropPattern : public OpRewritePattern { Value output = binaryOp.getC(); // Input is a block argument, do nothing. - if (A.dyn_cast() || B.dyn_cast()) + if (mlir::dyn_cast(A) || mlir::dyn_cast(B)) return failure(); // Input is a CPU tensor, do nothing. @@ -283,9 +284,9 @@ class ONNXConcatLayoutPropagatePattern : public OpRewritePattern { // for padding. // TODO: get this info from affine_map that is used for stickiyfing NHWC. return llvm::all_of(values, [&layoutAttr](Value v) { - if (v.getType().isa() && - v.getType().cast().hasRank()) { - ArrayRef dims = v.getType().cast().getShape(); + if (mlir::isa(v.getType()) && + mlir::cast(v.getType()).hasRank()) { + ArrayRef dims = mlir::cast(v.getType()).getShape(); if (isNHWCLayout(layoutAttr)) // Value is NCHW that will be directly unstickified from NHWC. // NCHW, C is at 1. diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp index f4a8a4b2fc..f97373764a 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp @@ -75,7 +75,7 @@ class UnstickStickRemovalPattern : public OpRewritePattern { std::optional stickLayout = stickOp.getLayout(); // Input is a block argument, ignore it. - if (stickInput.dyn_cast()) + if (mlir::dyn_cast(stickInput)) return failure(); // Get UnstickOp that produced the stick input. @@ -141,7 +141,7 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { return failure(); // Input is a block argument, ignore it. - if (stickInput.dyn_cast()) + if (mlir::dyn_cast(stickInput)) return failure(); // Input must have no affine layout. In other words, it has been normalized. @@ -182,8 +182,9 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { // Match shapes. Value stickRes = stickOp.getOut(); Value unstickInput = unstickOp.getX(); - MemRefType stickResType = stickRes.getType().dyn_cast(); - MemRefType unstickInputType = unstickInput.getType().dyn_cast(); + MemRefType stickResType = mlir::dyn_cast(stickRes.getType()); + MemRefType unstickInputType = + mlir::dyn_cast(unstickInput.getType()); if (!stickResType.hasStaticShape() || (stickResType.getShape() != unstickInputType.getShape())) return failure(); @@ -213,7 +214,7 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { /// /// * Example: /// -/// Consider the following code: +/// Consider the following code: /// ```mlir /// zlow.unstick(%stick, %A) {layout = "2D"}: memref<2x3xf16, #map2D>, memref<2x3xf32> /// affine.for @@ -240,7 +241,7 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern { /// as Transpose, Concat, and Split. /// /// * Why does this rewriting work? -/// +/// /// - This rewriting depends on the fact that `zlow.stick` and `zlow.unstick` /// maintain an affine map that maps one element in a memref to an element in /// another memref. Those maps are `#map2D` and `#map3D` in the above example. @@ -294,9 +295,9 @@ class UnstickLoadStoreStickRemovalPattern // Common types. Type stickifiedElementType = - stickifiedMemRef.getType().cast().getElementType(); + mlir::cast(stickifiedMemRef.getType()).getElementType(); Type cpuElementType = - cpuMemRef.getType().cast().getElementType(); + mlir::cast(cpuMemRef.getType()).getElementType(); // Stickified Memref must have affine layout to access elements. if (!hasNonIdentityLayout(stickifiedMemRef.getType())) @@ -558,7 +559,7 @@ class UnstickLoadStoreStickRemovalPattern Value storeValue = storeOp.getValue(); // Store's input must be defined by a memref.alloc. - if (destMemref.isa()) + if (mlir::isa(destMemref)) return false; Operation *allocOp = destMemref.getDefiningOp(); if (!isa(allocOp)) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index e39fb91e95..af73413a0e 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -255,7 +255,7 @@ class FrontendGenImpl { static onnx::TypeProto fromMlirToONNXType(Type mlirType) { onnx::TypeProto onnxType; - if (mlirType.isa()) { + if (mlir::isa(mlirType)) { // Done: Uninitialized TypeProto onnxType represents NoneType. } else if (auto mlirTensorType = dyn_cast(mlirType)) { onnx::TypeProto::Tensor &onnxTensorType = *onnxType.mutable_tensor_type(); @@ -340,9 +340,10 @@ class FrontendGenImpl { assert(elem_type.value_case() == onnx::TypeProto::kTensorType && "expect tensor inside sequence type"); Type mlir_elem_type = ImportTensorType(elem_type, dim_params); - if (!mlir_elem_type.isa()) + if (!mlir::isa(mlir_elem_type)) llvm_unreachable("Seq type is incorrect"); - Type seq_type = mlir::SeqType::get(mlir_elem_type.cast(), -1); + Type seq_type = + mlir::SeqType::get(mlir::cast(mlir_elem_type), -1); return seq_type; } llvm_unreachable("unexpected type"); @@ -786,8 +787,9 @@ class FrontendGenImpl { if (j < outputMap.size() && outputMap[j] >= MAX_NUM_TYPES) { // Mapping gives a connection with an input. Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType(); - if (inputType.isa()) { - Type elementType = inputType.cast().getElementType(); + if (mlir::isa(inputType)) { + Type elementType = + mlir::cast(inputType).getElementType(); auto outType = UnrankedTensorType::get(elementType); outputTypes.emplace_back(outType); } else { @@ -888,7 +890,7 @@ class FrontendGenImpl { getNodeInputs(node, inputs); auto attributes = ImportNodeAttributes(node); std::vector outputTypes; - auto inputType = inputs[0].getType().cast(); + auto inputType = mlir::cast(inputs[0].getType()); if (inputType.getElementType().isInteger(64)) { outputTypes.emplace_back( mlir::ONNXStringType::get(builder_.getContext())); @@ -1032,7 +1034,7 @@ class FrontendGenImpl { std::vector inputs; getNodeInputs(node, inputs); Type elementType = - inputs[0].getType().cast().getElementType(); + mlir::cast(inputs[0].getType()).getElementType(); llvm::SmallVector values( 1, builder_.getZeroAttr(elementType)); @@ -1074,7 +1076,7 @@ class FrontendGenImpl { const Type elementType = builder_.getIntegerType(64); const auto attributes = ImportNodeAttributes(node); for (auto attr : attributes) { - if (auto arrayAttr = attr.getValue().dyn_cast()) { + if (auto arrayAttr = mlir::dyn_cast(attr.getValue())) { const auto tensorType = RankedTensorType::get({(int64_t)arrayAttr.size()}, elementType); auto constantDenseAttribute = @@ -1449,7 +1451,7 @@ class FrontendGenImpl { SmallVector argAttrs; for (size_t k = 0; k < funcAttrsToMove.size(); ++k) { if (i < funcAttrsToMove[k].size()) { - auto name = (funcAttrsToMove[k].getValue()[i]).cast(); + auto name = mlir::cast(funcAttrsToMove[k].getValue()[i]); if (name) { NamedAttribute namedAttr = builder_.getNamedAttr(argAttrNames[k], name); diff --git a/src/Builder/ModelInputShaper.cpp b/src/Builder/ModelInputShaper.cpp index 8fb5441151..067ab67896 100644 --- a/src/Builder/ModelInputShaper.cpp +++ b/src/Builder/ModelInputShaper.cpp @@ -109,7 +109,7 @@ RankedTensorType forceShape( } // namespace Type ModelInputShaper::reshape(int inputIndex, Type inputType) const { - if (auto rankedTensorTy = inputType.dyn_cast()) { + if (auto rankedTensorTy = mlir::dyn_cast(inputType)) { ArrayRef origDims = rankedTensorTy.getShape(); // Update the input dimensions based on internal information. if (force_dim_dynamic_enabled_) { diff --git a/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp b/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp index 227e4a5dbb..10f982864d 100644 --- a/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp +++ b/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp @@ -46,7 +46,7 @@ class KrnlSeqAllocOpLowering : public ConversionPattern { MultiDialectBuilder create(rewriter, loc); Value outputSeq = thisOp.getResult(); - auto outputType = outputSeq.getType().cast(); + auto outputType = mlir::cast(outputSeq.getType()); Value alloc; if (outputType.isDynamicDim(0)) { llvm::SmallVector length(operandAdaptor.getLength()); diff --git a/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp b/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp index 75c02b8d5a..d09f020c54 100644 --- a/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp +++ b/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp @@ -58,10 +58,10 @@ class KrnlSeqExtractOpLowering : public ConversionPattern { rewriter.replaceOp(op, output); return success(); } else { - if (!output.getType().isa()) + if (!mlir::isa(output.getType())) llvm_unreachable( "Not implemented: type of onnx seq element is not tensor"); - auto outputType = output.getType().cast(); + auto outputType = mlir::cast(output.getType()); SmallVector allocParams; for (size_t i = 0; i < outputType.getShape().size(); i++) { if (outputType.isDynamicDim(i)) { diff --git a/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp b/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp index 1a9068f4b1..cb58f47a1e 100644 --- a/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp +++ b/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp @@ -44,7 +44,8 @@ class KrnlSeqStoreOpLowering : public ConversionPattern { MultiDialectBuilder create(rewriter, loc); // Allocate a new tensor and copy input tensor into it - auto inputType = operandAdaptor.getInput().getType().cast(); + auto inputType = + mlir::cast(operandAdaptor.getInput().getType()); SmallVector allocParams; for (size_t i = 0; i < inputType.getShape().size(); i++) { if (inputType.isDynamicDim(i)) { @@ -56,8 +57,8 @@ class KrnlSeqStoreOpLowering : public ConversionPattern { // Cast the input tensor to the element type of the sequence auto seq = operandAdaptor.getSeq(); - auto seqElementType = - seq.getType().cast().getElementType().cast(); + auto seqElementType = mlir::cast( + mlir::cast(seq.getType()).getElementType()); auto casted = create.mem.cast(alloc, seqElementType); // Store the tensor diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp index 1ef6b0467a..abcd008004 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp @@ -391,8 +391,8 @@ static void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &builder, for (int boundType = 0; boundType < 2; boundType++) { auto &operands = boundType == 0 ? lbOperands : ubOperands; auto &map = boundType == 0 ? lbMap : ubMap; - map = - boundMapAttrs[boundIdx + boundType].cast().getValue(); + map = mlir::cast(boundMapAttrs[boundIdx + boundType]) + .getValue(); operands.insert( operands.end(), operandItr, operandItr + map.getNumInputs()); std::advance(operandItr, map.getNumInputs()); diff --git a/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp b/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp index 8aef9629f1..352bc9d6be 100644 --- a/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp +++ b/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp @@ -50,9 +50,9 @@ class KrnlCopyFromBufferLowering : public ConversionPattern { Value destMemref(operandAdaptor.getDest()); ValueRange startVals(operandAdaptor.getStarts()); int64_t destRank = - destMemref.getType().cast().getShape().size(); + mlir::cast(destMemref.getType()).getShape().size(); int64_t buffRank = - buffMemref.getType().cast().getShape().size(); + mlir::cast(buffMemref.getType()).getShape().size(); int64_t destOffset = destRank - buffRank; assert(destOffset >= 0 && "offset expected non negative"); diff --git a/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp b/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp index e4e5399f29..a08766d63a 100644 --- a/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp +++ b/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp @@ -51,9 +51,9 @@ class KrnlCopyToBufferLowering : public ConversionPattern { ValueRange startVals(operandAdaptor.getStarts()); Value padVal(operandAdaptor.getPadValue()); int64_t srcRank = - sourceMemref.getType().cast().getShape().size(); + mlir::cast(sourceMemref.getType()).getShape().size(); int64_t buffRank = - buffMemref.getType().cast().getShape().size(); + mlir::cast(buffMemref.getType()).getShape().size(); int64_t srcOffset = srcRank - buffRank; assert(srcOffset >= 0 && "offset expected non negative"); SmallVector starts, bufferReadUBs, bufferPadUBs, pads, diff --git a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp index 769a0771f6..c33bc72fc0 100644 --- a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp +++ b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp @@ -54,8 +54,8 @@ class KrnlMatmulLowering : public ConversionPattern { bool fullUnrollAndJam = matmulOp.getUnroll(); // Operands and types. - Type elementType = - operandAdaptor.getA().getType().cast().getElementType(); + Type elementType = mlir::cast(operandAdaptor.getA().getType()) + .getElementType(); bool simdize = matmulOp.getSimdize(); // Init scope and emit constants. Location loc = matmulOp.getLoc(); @@ -241,7 +241,7 @@ class KrnlMatmulLowering : public ConversionPattern { /* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) { genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, - vectorLen, fullUnrollAndJam); + vectorLen, fullUnrollAndJam); }, /* has some partial tiles */ [&](AffineBuilderKrnlMem &createAffine) { // Trip regardless of full/partial for N & K // Test if SIMD dim (M) is full. @@ -271,7 +271,7 @@ class KrnlMatmulLowering : public ConversionPattern { /* then full */ [&](AffineBuilderKrnlMem &createAffine) { genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, - fullUnrollAndJam); + fullUnrollAndJam); }, /* else partial */ [&](AffineBuilderKrnlMem &createAffine) { genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, iTrip, jTrip, kTrip, false); diff --git a/src/Conversion/KrnlToAffine/KrnlMemset.cpp b/src/Conversion/KrnlToAffine/KrnlMemset.cpp index 40f0455550..cb536d89ce 100644 --- a/src/Conversion/KrnlToAffine/KrnlMemset.cpp +++ b/src/Conversion/KrnlToAffine/KrnlMemset.cpp @@ -44,7 +44,7 @@ class KrnlMemsetLowering : public ConversionPattern { // If delayed but the input memref has not normalized yet, do nothing. if (delayed && - !destMemRef.getType().cast().getLayout().isIdentity()) + !mlir::cast(destMemRef.getType()).getLayout().isIdentity()) return failure(); MultiDialectBuilder create( diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp index 6297d92a90..608c4e0d96 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp @@ -277,9 +277,8 @@ void recordInputOutputMemRefTypes(ModuleOp &module, auto *entryPointFunc = module.lookupSymbol(entryPointFuncName); assert(entryPointFunc && isa(entryPointFunc) && "entry point func must exist and be an llvm func op"); - auto entryPointTy = dyn_cast(entryPointFunc) - .getFunctionType() - .dyn_cast(); + auto entryPointTy = mlir::dyn_cast( + dyn_cast(entryPointFunc).getFunctionType()); SmallVector inputTypes, outputTypes; for (Type ty : entryPointTy.getInputs()) inputTypes.emplace_back(dyn_cast(ty)); @@ -414,10 +413,10 @@ void genSignatureFunction(ModuleOp &module, LLVM::GlobalOp globalEntryPoint = entryGlobalOps[j]; LLVM::GlobalOp globalSignature = (i == 0) ? inSigGlobalOps[j] : outSigGlobalOps[j]; - assert(globalEntryPoint.getValueAttr().isa() && + assert(mlir::isa(globalEntryPoint.getValueAttr()) && "Entry point value is not StringAttr"); StringAttr entryPointValueAttr = - globalEntryPoint.getValueAttr().cast(); + mlir::cast(globalEntryPoint.getValueAttr()); // Return the signature if found. create.llvm.ifThenElse(/*cond=*/ @@ -492,7 +491,7 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, if (rawData.empty()) return WalkResult::advance(); - auto valueAttr = op.getValue().value().cast(); + auto valueAttr = mlir::cast(op.getValue().value()); if (valueAttr.isSplat() || rawData.size() <= singleThreshold) return WalkResult::advance(); @@ -624,7 +623,9 @@ void loadConstantsFromFile(ModuleOp &module, bool zOS = isZOS(module); for (auto entryGlobalOp : entryGlobalOps) { std::string entryName = - entryGlobalOp.getValue().value().cast().getValue().str(); + mlir::cast(entryGlobalOp.getValue().value()) + .getValue() + .str(); // Entry point name is encoded in EBCDIC on z/OS. entryName = (zOS) ? krnl::e2a_s(entryName) : entryName; // Erase the null symbol. @@ -663,9 +664,7 @@ void loadConstantsFromFile(ModuleOp &module, LLVMBuilder::SymbolPostfix(module, EXTERNAL_CONSTANT_PREFIX + "filesize"); auto fsizeGlobalOp = module.lookupSymbol(fsizeSymbol); assert(fsizeGlobalOp && "Could not find the global op for filesize"); - int64_t dataSize = fsizeGlobalOp.getValue() - .value() - .cast() + int64_t dataSize = mlir::cast(fsizeGlobalOp.getValue().value()) .getValue() .getSExtValue(); // Get the global op for isLE. @@ -673,9 +672,7 @@ void loadConstantsFromFile(ModuleOp &module, LLVMBuilder::SymbolPostfix(module, EXTERNAL_CONSTANT_PREFIX + "isLE"); auto isleGlobalOp = module.lookupSymbol(isleSymbol); assert(isleGlobalOp && "Could not find the global op for data isle"); - int64_t isle = isleGlobalOp.getValue() - .value() - .cast() + int64_t isle = mlir::cast(isleGlobalOp.getValue().value()) .getValue() .getSExtValue(); // Get the packedConst global. @@ -704,9 +701,7 @@ void loadConstantsFromFile(ModuleOp &module, EXTERNAL_CONSTANT_PREFIX + "offset" + constantName; auto offsetGlobalOp = module.lookupSymbol(offsetSymbol); assert(offsetGlobalOp && "Could not find the global op for offset"); - int64_t offset = offsetGlobalOp.getValue() - .value() - .cast() + int64_t offset = mlir::cast(offsetGlobalOp.getValue().value()) .getValue() .getSExtValue(); diff --git a/src/Conversion/KrnlToLLVM/KrnlCall.cpp b/src/Conversion/KrnlToLLVM/KrnlCall.cpp index bd07d42c5e..6d570ab3d4 100644 --- a/src/Conversion/KrnlToLLVM/KrnlCall.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlCall.cpp @@ -60,9 +60,9 @@ class KrnlCallOpLowering : public ConversionPattern { // Handle the Attributes for (auto namedAttr : op->getAttrs()) { // Avoid the funcName() Attribute - if (namedAttr.getName().getValue().equals("funcName")) + if (namedAttr.getName().getValue() == "funcName") continue; - if (namedAttr.getName().getValue().equals("numOfOutput")) + if (namedAttr.getName().getValue() == "numOfOutput") continue; handleOneAttribute( rewriter, op, namedAttr.getValue(), parameterTypeList, parameterList); @@ -104,7 +104,7 @@ class KrnlCallOpLowering : public ConversionPattern { Type ty = original.getType(); if (auto originalMemRef = dyn_cast(ty)) { auto int64Ty = IntegerType::get(context, 64); - auto memRefTy = parameter.getType().dyn_cast(); + auto memRefTy = mlir::dyn_cast(parameter.getType()); auto memRefRank = krnl::getRankFromMemRefType(memRefTy); auto memRefRankVal = create.llvm.constant(int64Ty, (int64_t)memRefRank); Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, @@ -119,7 +119,7 @@ class KrnlCallOpLowering : public ConversionPattern { parameterTypeList.emplace_back(opaquePtrTy); parameterList.emplace_back(omTensor); omTensors.emplace_back(omTensor); - } else if (ty.isa()) { + } else if (mlir::isa(ty)) { // Generate llvm null pinter for NoneType auto int8Ty = IntegerType::get(context, 8); auto opaquePtrTy = getPointerType(context, int8Ty); @@ -176,7 +176,7 @@ class KrnlCallOpLowering : public ConversionPattern { // In future, the attributes should be converted in krnl.call builder. // This code passed onnx-mlir-opt --convert-krnl-to-llvm test case, // but failed in onnx-milr for the tensor type for the attribute - auto tensorTy = denseAttr.getType().cast(); + auto tensorTy = mlir::cast(denseAttr.getType()); auto memRefTy = MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()); Value constantGlobal = diff --git a/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp b/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp index 7180bfb8bd..47eaa8bb53 100644 --- a/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp @@ -148,14 +148,16 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // OMInitCompatibleAccelX's signature is `i64 (i64)`. if (Attribute maccelAttr = module->getAttrOfType<::mlir::Attribute>("onnx-mlir.accels")) { - assert( - maccelAttr.isa() && "onnx-mlir.accels must be ArrayAttr"); - ArrayAttr accels = maccelAttr.cast(); + assert(mlir::isa(maccelAttr) && + "onnx-mlir.accels must be ArrayAttr"); + ArrayAttr accels = mlir::cast(maccelAttr); Value zeroI64 = create.llvm.constant(int64Ty, (int64_t)0); for (uint64_t i = 0; i < accels.size(); ++i) { - assert(accels[i].isa() && "Attribute must be StringAttr"); - StringRef accelStr = accels.getValue()[i].cast().getValue(); + assert( + mlir::isa(accels[i]) && "Attribute must be StringAttr"); + StringRef accelStr = + mlir::cast(accels.getValue()[i]).getValue(); std::pair NameAndVersion = accelStr.split('-'); uint64_t versionNumberInHex = std::stoul(NameAndVersion.second.str(), nullptr, 16); @@ -199,9 +201,8 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // struct but input types are unpacked into a single list of scalar types. auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName.lower()); - auto staticEntryPointFuncTy = cast(staticEntryPointFunc) - .getFunctionType() - .cast(); + auto staticEntryPointFuncTy = mlir::cast( + cast(staticEntryPointFunc).getFunctionType()); LLVM_DEBUG(llvm::dbgs() << "Static entry point function type: " << staticEntryPointFuncTy << "\n"); // Static entry point is wrapped with prefix `_mlir_ciface` automatically by @@ -216,8 +217,8 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { "entry point func must exist and be an llvm func op"); auto wrappedStaticEntryPointOp = cast(wrappedStaticEntryPointFunc); - auto wrappedStaticEntryPointTy = wrappedStaticEntryPointOp.getFunctionType() - .cast(); + auto wrappedStaticEntryPointTy = mlir::cast( + wrappedStaticEntryPointOp.getFunctionType()); Value omTensorPtrArr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::GET_OMT_ARRAY, {omTensorInputs}); @@ -264,7 +265,8 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // output. create.llvm.call({}, wrappedStaticEntryPointFuncName, staticInputs); Value outMemRefs = create.llvm.load(memRefOutTy, ptrToOutMemRef); - auto outMemRefsType = outMemRefs.getType().dyn_cast(); + auto outMemRefsType = + mlir::dyn_cast(outMemRefs.getType()); std::vector outMemRefList; if (numOutputs == 1) { @@ -292,7 +294,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // Get the i-th memref returned, convert to a dynamic memref and store it // in the wrappedOutput. Value memRef = outMemRefList.at(i); - auto outMemRefTy = memRef.getType().dyn_cast(); + auto outMemRefTy = mlir::dyn_cast(memRef.getType()); int64_t outMemRefRank = krnl::getRankFromMemRefType(outMemRefTy); Value outMemRefRankVal = create.llvm.constant(int64Ty, (int64_t)outMemRefRank); @@ -328,7 +330,8 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { LLVM::LLVMFuncOp &dynamicEntryPointFunc, Location &loc) const { // Add entry block: auto *entryPointEntryBlock = new Block(); - auto dynEntryPointFuncType = dynEntryPoint.cast(); + auto dynEntryPointFuncType = + mlir::cast(dynEntryPoint); dynamicEntryPointFunc.push_back(entryPointEntryBlock); llvm::SmallVector argTypes; for (size_t i = 0; i < dynEntryPointFuncType.getNumParams(); i++) @@ -352,7 +355,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { Value dataPtr = RuntimeAPI::callApi( rewriter, loc, apiRegistry, RuntimeAPI::API::GET_DATA, {rtMemRef}); dataPtr = create.llvm.bitcast( - memRefTy.cast().getBody()[0], dataPtr); + mlir::cast(memRefTy).getBody()[0], dataPtr); memRef = create.llvm.insertValue(memRefTy, memRef, dataPtr, {0}); memRef = create.llvm.insertValue(memRefTy, memRef, dataPtr, {1}); @@ -362,7 +365,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // Get rank, sizes array ptr and strides array ptr. auto rank = - krnl::getRankFromMemRefType(memRefTy.cast()); + krnl::getRankFromMemRefType(mlir::cast(memRefTy)); Value sizesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::GET_DATA_SHAPE, {rtMemRef}); Value stridesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, diff --git a/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp b/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp index 7c923515ac..667683d826 100644 --- a/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp @@ -63,14 +63,12 @@ class KrnlFindIndexOpLowering : public ConversionPattern { llvm_unreachable("unexpected inputType"); }); - Type GType = operandAdaptor.getG() - .getType() - .cast() - .getBody()[1]; - Type VType = operandAdaptor.getV() - .getType() - .cast() - .getBody()[1]; + Type GType = + mlir::cast(operandAdaptor.getG().getType()) + .getBody()[1]; + Type VType = + mlir::cast(operandAdaptor.getV().getType()) + .getBody()[1]; // Remaining operands. Value extractedGPtr = diff --git a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp b/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp index 1c8a145a40..ac0d8dc70b 100644 --- a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp @@ -55,19 +55,19 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { // The element type of the array. const Type type = op->getResult(0).getType(); - const MemRefType memRefTy = type.cast(); + const MemRefType memRefTy = mlir::cast(type); const Type constantElementType = typeConverter->convertType(memRefTy.getElementType()); Type globalType = constantElementType; // The llvm type of the global (example: [2 x [8 x float]]). - const auto shape = (krnlGlobalOp.getShape()).dyn_cast(); + const auto shape = mlir::dyn_cast(krnlGlobalOp.getShape()); if (shape.empty()) - globalType = LLVM::LLVMArrayType::get(globalType.cast(), 1); + globalType = LLVM::LLVMArrayType::get(mlir::cast(globalType), 1); else { for (int i = shape.size() - 1; i >= 0; i--) globalType = LLVM::LLVMArrayType::get( - globalType.cast(), ArrayAttrIntVal(shape, i)); + mlir::cast(globalType), ArrayAttrIntVal(shape, i)); } // Create the global at the entry of the module. @@ -112,15 +112,16 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { private: static int64_t ArrayAttrIntVal(ArrayAttr a, int i) { - return (a.getValue()[i]).cast().getInt(); + return mlir::cast(a.getValue()[i]).getInt(); } LLVM::GlobalOp lowerDenseResourceConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType, ConversionPatternRewriter &rewriter) const { assert(krnlGlobalOp.getValue().has_value() && "Expecting KrnlGlobalOp with a valid value"); - assert(krnlGlobalOp.getValue().value().isa() && - "Expecting a global with an dense resource elements attribute"); + assert( + mlir::isa(krnlGlobalOp.getValue().value()) && + "Expecting a global with an dense resource elements attribute"); MLIRContext *context = krnlGlobalOp.getContext(); Location loc = krnlGlobalOp.getLoc(); @@ -130,11 +131,10 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - auto blob = krnlGlobalOp.getValue() - .value() - .cast() - .getRawHandle() - .getBlob(); + auto blob = + mlir::cast(krnlGlobalOp.getValue().value()) + .getRawHandle() + .getBlob(); assert(blob && "Expecting dense resource with a valid blob"); ArrayRef rawData = blob->getData(); @@ -158,7 +158,7 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const { assert(krnlGlobalOp.getValue().has_value() && "Expecting KrnlGlobalOp with a valid value"); - assert(krnlGlobalOp.getValue().value().isa() && + assert(mlir::isa(krnlGlobalOp.getValue().value()) && "Expecting a global with an dense elements attribute"); Location loc = krnlGlobalOp.getLoc(); @@ -172,11 +172,11 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { rewriter.setInsertionPointToStart(module.getBody()); DenseElementsAttr denseAttr = - krnlGlobalOp.getValue().value().cast(); + mlir::cast(krnlGlobalOp.getValue().value()); uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp); LLVM::GlobalOp global; - if ((!denseAttr.getElementType().isa()) && + if (!(mlir::isa(denseAttr.getElementType())) && (!denseAttr.isSplat()) && (sizeInBytes > 1024)) { ArrayRef rawData = denseAttr.getRawData(); assert( @@ -189,7 +189,7 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(), llvmStringAttr); } else { - if (denseAttr.getElementType().isa()) + if (mlir::isa(denseAttr.getElementType())) global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter); else global = create.llvm.globalOp(globalType, @@ -246,13 +246,13 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { uint64_t computeSizeInBytes(KrnlGlobalOp &krnlGlobalOp) const { // Compute total number of elements. - const auto shape = (krnlGlobalOp.getShape()).dyn_cast(); + const auto shape = mlir::dyn_cast(krnlGlobalOp.getShape()); uint64_t numElements = 1; for (unsigned int i = 0; i < shape.size(); ++i) numElements *= ArrayAttrIntVal(shape, i); const auto type = krnlGlobalOp.getResult().getType(); - const auto memRefTy = type.cast(); + const auto memRefTy = mlir::cast(type); // Special handling for bool. if (memRefTy.getElementType().isInteger(1)) @@ -283,14 +283,14 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern { // the address of the global strings into an array. Return the array address. LLVM::GlobalOp lowerStringLiteral( KrnlGlobalOp &krnlGlobalOp, Type globalType, OpBuilder &builder) const { - assert(krnlGlobalOp.getValue().value().isa() && + assert(mlir::isa(krnlGlobalOp.getValue().value()) && "Expecting a dense value"); Location loc = krnlGlobalOp.getLoc(); MultiDialectBuilder create(builder, loc); DenseElementsAttr denseAttr = - krnlGlobalOp.getValue().value().cast(); + mlir::cast(krnlGlobalOp.getValue().value()); Type i8PtrType = getI8PointerType(builder.getContext()); diff --git a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp index 4590221588..7312be6e61 100644 --- a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp @@ -61,15 +61,15 @@ class KrnlInstrumentOpLowering : public ConversionPattern { StringRef nodeName; if (instrumentOp.getNodeName().has_value()) nodeName = instrumentOp.getNodeName().value(); - else if (auto nameLoc = loc.dyn_cast()) + else if (auto nameLoc = mlir::dyn_cast(loc)) nodeName = nameLoc.getName(); - else if (auto fusedLoc = loc.dyn_cast()) { + else if (auto fusedLoc = mlir::dyn_cast(loc)) { // Combine each location name and set it as nodeName, appended by "-". std::string name; for (Location locIt : fusedLoc.getLocations()) { - if (auto nameLocIt = locIt.dyn_cast()) + if (auto nameLocIt = mlir::dyn_cast(locIt)) name += nameLocIt.getName().str() + "-"; - else if (auto fileLineColLoc = locIt.dyn_cast()) { + else if (auto fileLineColLoc = mlir::dyn_cast(locIt)) { std::string filename = llvm::sys::path::filename(fileLineColLoc.getFilename().str()) .str(); @@ -83,7 +83,7 @@ class KrnlInstrumentOpLowering : public ConversionPattern { name.pop_back(); // remove last "-" Location newLoc = NameLoc::get(rewriter.getStringAttr(name)); nodeName = cast(newLoc).getName(); - } else if (auto fileLineColLoc = loc.dyn_cast()) { + } else if (auto fileLineColLoc = mlir::dyn_cast(loc)) { std::string filename = llvm::sys::path::filename(fileLineColLoc.getFilename().str()).str(); std::string name = diff --git a/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp b/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp index 4e15506468..d2aa7c1f35 100644 --- a/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp @@ -57,9 +57,10 @@ class KrnlMemcpyOpLowering : public ConversionPattern { Type i1Ty = IntegerType::get(context, 1); Type i64Ty = IntegerType::get(context, 64); Type i8PtrTy = getPointerType(context, IntegerType::get(context, 8)); - Type elementType = src.getType().cast().getBody()[1]; + Type elementType = + mlir::cast(src.getType()).getBody()[1]; int64_t eltSize = getMemRefEltSizeInBytes( - memcpyOp.getSrc().getType().dyn_cast()); + mlir::dyn_cast(memcpyOp.getSrc().getType())); Value eltSizeInBytes = create.llvm.constant(i64Ty, eltSize); // Get a symbol reference to the memcpy function, inserting it if necessary. diff --git a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp index 95d33c0341..f254dc1074 100644 --- a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp @@ -45,7 +45,7 @@ class KrnlPrintTensorOpLowering : public ConversionPattern { StringRef msg = printTensorOp.getMsg(); Value input = operandAdaptor.getInput(); Value originalInput = printTensorOp.getInput(); - assert(input.getType().isa() && + assert(mlir::isa(input.getType()) && "expecting LLVMStructType"); ModuleOp module = printTensorOp->getParentOfType(); @@ -55,13 +55,14 @@ class KrnlPrintTensorOpLowering : public ConversionPattern { // Get a symbol reference to the runtime function to use, creating one if // necessary. auto int64Ty = IntegerType::get(context, 64); - auto memRefTy = input.getType().dyn_cast(); + auto memRefTy = mlir::dyn_cast(input.getType()); auto memRefRank = krnl::getRankFromMemRefType(memRefTy); Value memRefRankVal = create.llvm.constant(int64Ty, (int64_t)memRefRank); Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal}); - Type elemTy = originalInput.getType().cast().getElementType(); + Type elemTy = + mlir::cast(originalInput.getType()).getElementType(); krnl::fillOMTensorWithMemRef(input, elemTy, omTensor, false /*outOwning*/, rewriter, loc, apiRegistry, module); LLVM::GlobalOp globalStr = krnl::getOrCreateGlobalString( diff --git a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp index 7b3772c485..0e2ece621c 100644 --- a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp @@ -50,10 +50,9 @@ class KrnlRandomNormalOpLowering : public ConversionPattern { getOrInsertRandomNormal(rewriter, parentModule, inType); // First operand. - Type outputType = operandAdaptor.getOutput() - .getType() - .cast() - .getBody()[1]; + Type outputType = + mlir::cast(operandAdaptor.getOutput().getType()) + .getBody()[1]; Value alignedOutput = create.llvm.extractValue(outputType, operandAdaptor.getOutput(), {1}); diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp index 8924baeef4..979eb3db50 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp @@ -119,10 +119,10 @@ int64_t getRankFromMemRefType(LLVM::LLVMStructType memRefTy) { assert((numElems == 3 || numElems == 5) && "Expect MemRef type to contain either 3 or 5 elements."); - return (numElems == 3) ? 0 // MemRef refers to a scalar. - : memRefTy.getBody()[3] - .cast() - .getNumElements(); + return (numElems == 3) + ? 0 // MemRef refers to a scalar. + : mlir::cast(memRefTy.getBody()[3]) + .getNumElements(); } // Convert an MLIR type to the correspoding ONNX type. @@ -137,7 +137,7 @@ void fillOMTensorWithMemRef(Value &outMemRef, Type elemTy, Value &outOMTensor, int64_t outOwning, PatternRewriter &rewriter, const Location &loc, const RuntimeAPIRegistry &apiRegistry, ModuleOp &module) { MLIRContext *context = module.getContext(); - auto outMemRefTy = outMemRef.getType().dyn_cast(); + auto outMemRefTy = mlir::dyn_cast(outMemRef.getType()); auto int64Ty = IntegerType::get(context, 64); MultiDialectBuilder create(rewriter, loc); @@ -293,7 +293,9 @@ Operation *getFirstEntryOpInBlock( Operation *firstEntryPointOp = nullptr; for (auto entryGlobalOp : entryGlobalOps) { std::string entryName = - entryGlobalOp.getValue().value().cast().getValue().str(); + mlir::cast(entryGlobalOp.getValue().value()) + .getValue() + .str(); // Entry point name is encoded in EBCDIC on z/OS. entryName = isZOS(module) ? krnl::e2a_s(entryName) : entryName; @@ -315,14 +317,15 @@ ArrayRef getRawData(KrnlGlobalOp &op) { auto value = op.getValue().value(); TypeSwitch(value) .Case([&](DenseResourceElementsAttr attr) { - auto blob = - value.cast().getRawHandle().getBlob(); + auto blob = mlir::cast(value) + .getRawHandle() + .getBlob(); assert(blob && "Expecting dense resource with a valid blob"); rawData = blob->getData(); }) .Case([&](DenseElementsAttr attr) { DenseElementsAttr denseAttr = - value.dyn_cast_or_null(); + mlir::dyn_cast_or_null(value); rawData = denseAttr.getRawData(); }) .Default([&](Attribute attr) { return; }); @@ -333,7 +336,8 @@ bool isZOS(ModuleOp module) { bool zOS = false; if (Attribute mtripleAttr = module->getAttrOfType<::mlir::Attribute>("llvm.target_triple")) - zOS = llvm::Triple(mtripleAttr.cast().getValue()).isOSzOS(); + zOS = + llvm::Triple(mlir::cast(mtripleAttr).getValue()).isOSzOS(); return zOS; } diff --git a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp index fc963d50e7..256d00572a 100644 --- a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp @@ -43,7 +43,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto krnlVectorTypeCastOp = cast(op); MemRefType sourceType = - krnlVectorTypeCastOp.getOperand().getType().cast(); + mlir::cast(krnlVectorTypeCastOp.getOperand().getType()); MemRefType targetType = krnlVectorTypeCastOp.getType(); if (!isSupportedMemRefType(targetType) || !isSupportedMemRefType(sourceType)) @@ -114,7 +114,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { // There is the implicit expectation that the last dimension of the // original memory is a multiple of the vector length. Value vecWidth = createIndexAttrConstant(rewriter, loc, indexType, - targetType.getElementType().cast().getNumElements()); + mlir::cast(targetType.getElementType()).getNumElements()); sizes.push_back(rewriter.create(loc, srcMemRefDesc.size(rewriter, loc, sourceType.getRank() - 1), vecWidth)); diff --git a/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp b/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp index 9acd72236a..5d810aa6eb 100644 --- a/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp +++ b/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp @@ -47,7 +47,7 @@ Value RuntimeAPI::callApi(OpBuilder &builder, Location loc, SmallVector outputTys; const RuntimeAPI &runtimeAPI = registry.getAPI(apiId); auto outputTy = runtimeAPI.outputTy; - if (!outputTy.isa()) + if (!mlir::isa(outputTy)) outputTys.emplace_back(outputTy); return create.llvm.call(ArrayRef(outputTys), registry.getAPI(apiId).symbolRef, ArrayRef(params)); diff --git a/src/Conversion/ONNXConversionCommon/RNN/LSTM.cpp b/src/Conversion/ONNXConversionCommon/RNN/LSTM.cpp index 3de32358f6..81f89839ed 100644 --- a/src/Conversion/ONNXConversionCommon/RNN/LSTM.cpp +++ b/src/Conversion/ONNXConversionCommon/RNN/LSTM.cpp @@ -50,15 +50,15 @@ getActivationPack(ONNXLSTMOp *op) { // Forward activations. if (activationArrAttr.size() > 0) { activationForward.f.name = - activationArrAttr[0].cast().getValue(); + mlir::cast(activationArrAttr[0]).getValue(); } if (activationArrAttr.size() > 1) { activationForward.g.name = - activationArrAttr[1].cast().getValue(); + mlir::cast(activationArrAttr[1]).getValue(); } if (activationArrAttr.size() > 2) { activationForward.h.name = - activationArrAttr[2].cast().getValue(); + mlir::cast(activationArrAttr[2]).getValue(); } } @@ -67,15 +67,17 @@ getActivationPack(ONNXLSTMOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 3; if (activationArrAttr.size() > startIndex) { activationReverse.f.name = - activationArrAttr[startIndex].cast().getValue(); + mlir::cast(activationArrAttr[startIndex]).getValue(); } if (activationArrAttr.size() > startIndex + 1) { activationReverse.g.name = - activationArrAttr[startIndex + 1].cast().getValue(); + mlir::cast(activationArrAttr[startIndex + 1]) + .getValue(); } if (activationArrAttr.size() > startIndex + 2) { activationReverse.h.name = - activationArrAttr[startIndex + 2].cast().getValue(); + mlir::cast(activationArrAttr[startIndex + 2]) + .getValue(); } } } @@ -86,13 +88,13 @@ getActivationPack(ONNXLSTMOp *op) { if (direction == FORWARD || direction == BIDIRECTIONAL) { // Forward activations. if (activationArrAttr.size() > 0) { - activationForward.f.alpha = activationArrAttr[0].cast(); + activationForward.f.alpha = mlir::cast(activationArrAttr[0]); } if (activationArrAttr.size() > 1) { - activationForward.g.alpha = activationArrAttr[1].cast(); + activationForward.g.alpha = mlir::cast(activationArrAttr[1]); } if (activationArrAttr.size() > 2) { - activationForward.h.alpha = activationArrAttr[2].cast(); + activationForward.h.alpha = mlir::cast(activationArrAttr[2]); } } @@ -101,15 +103,15 @@ getActivationPack(ONNXLSTMOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 3; if (activationArrAttr.size() > startIndex) { activationReverse.f.alpha = - activationArrAttr[startIndex].cast(); + mlir::cast(activationArrAttr[startIndex]); } if (activationArrAttr.size() > startIndex + 1) { activationReverse.g.alpha = - activationArrAttr[startIndex + 1].cast(); + mlir::cast(activationArrAttr[startIndex + 1]); } if (activationArrAttr.size() > startIndex + 2) { activationReverse.h.alpha = - activationArrAttr[startIndex + 2].cast(); + mlir::cast(activationArrAttr[startIndex + 2]); } } } @@ -120,13 +122,13 @@ getActivationPack(ONNXLSTMOp *op) { if (direction == FORWARD || direction == BIDIRECTIONAL) { // Forward activations. if (activationArrAttr.size() > 0) { - activationForward.f.beta = activationArrAttr[0].cast(); + activationForward.f.beta = mlir::cast(activationArrAttr[0]); } if (activationArrAttr.size() > 1) { - activationForward.g.beta = activationArrAttr[1].cast(); + activationForward.g.beta = mlir::cast(activationArrAttr[1]); } if (activationArrAttr.size() > 2) { - activationForward.h.beta = activationArrAttr[2].cast(); + activationForward.h.beta = mlir::cast(activationArrAttr[2]); } } @@ -135,15 +137,15 @@ getActivationPack(ONNXLSTMOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 3; if (activationArrAttr.size() > startIndex) { activationReverse.f.beta = - activationArrAttr[startIndex].cast(); + mlir::cast(activationArrAttr[startIndex]); } if (activationArrAttr.size() > startIndex + 1) { activationReverse.g.beta = - activationArrAttr[startIndex + 1].cast(); + mlir::cast(activationArrAttr[startIndex + 1]); } if (activationArrAttr.size() > startIndex + 2) { activationReverse.h.beta = - activationArrAttr[startIndex + 2].cast(); + mlir::cast(activationArrAttr[startIndex + 2]); } } } diff --git a/src/Conversion/ONNXConversionCommon/RNN/RNNBase.cpp b/src/Conversion/ONNXConversionCommon/RNN/RNNBase.cpp index 5ab59dd830..f5680ca301 100644 --- a/src/Conversion/ONNXConversionCommon/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXConversionCommon/RNN/RNNBase.cpp @@ -21,7 +21,7 @@ namespace onnx_mlir { // Get a dimension of the tensor's shape. int64_t dimAt(Value val, int index) { - return val.getType().cast().getShape()[index]; + return mlir::cast(val.getType()).getShape()[index]; } // Apply an activation function on a given scalar operand. diff --git a/src/Conversion/ONNXToKrnl/Additional/Custom.cpp b/src/Conversion/ONNXToKrnl/Additional/Custom.cpp index 7b3b72b1c5..952760a41e 100644 --- a/src/Conversion/ONNXToKrnl/Additional/Custom.cpp +++ b/src/Conversion/ONNXToKrnl/Additional/Custom.cpp @@ -47,7 +47,7 @@ struct ONNXCustomOpLowering : public OpConversionPattern { for (size_t idx = 0; idx < op->getResultTypes().size(); idx++) { Type ty = op->getResultTypes()[idx]; MemRefType outputMemRefType = - typeConverter->convertType(ty).cast(); + mlir::cast(typeConverter->convertType(ty)); outputMemRefTypes.emplace_back(outputMemRefType); Value alloc = create.mem.alignedAlloc( outputMemRefType, shapeHelper.getOutputDims(idx)); diff --git a/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp b/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp index 9c7bab5533..17a435b7c6 100644 --- a/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp +++ b/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp @@ -48,15 +48,15 @@ struct ONNXLayoutTransformOpLowering // Convert the input type to MemRefType. Type inConvertedType = typeConverter->convertType(data.getType()); - assert(inConvertedType && inConvertedType.isa() && + assert(inConvertedType && mlir::isa(inConvertedType) && "Failed to convert type to MemRefType"); - MemRefType inMemRefType = inConvertedType.cast(); + MemRefType inMemRefType = mlir::cast(inConvertedType); // Convert the output type to MemRefType. Type outputTensorType = *op->result_type_begin(); Type outConvertedType = typeConverter->convertType(outputTensorType); - assert(outConvertedType && outConvertedType.isa() && + assert(outConvertedType && mlir::isa(outConvertedType) && "Failed to convert type to MemRefType"); - MemRefType outMemRefType = outConvertedType.cast(); + MemRefType outMemRefType = mlir::cast(outConvertedType); // Note that by definition the input and output of LayoutTransformOp have // the same logical rank. The only difference between them should be their diff --git a/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp b/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp index 018d8a0188..1f13fd0656 100644 --- a/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp +++ b/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp @@ -43,9 +43,9 @@ struct ONNXShapeTransformOpLowering : public ConversionPattern { shapeHelper.computeShapeAndAssertOnFailure(); // Input and output types. - MemRefType inputMemRefType = input.getType().cast(); - MemRefType outputMemRefType = - typeConverter->convertType(*op->result_type_begin()).cast(); + MemRefType inputMemRefType = mlir::cast(input.getType()); + MemRefType outputMemRefType = mlir::cast( + typeConverter->convertType(*op->result_type_begin())); uint64_t inputRank = inputMemRefType.getRank(); uint64_t outputRank = outputMemRefType.getRank(); diff --git a/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp b/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp index 0beb78a62e..01249466ff 100644 --- a/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp +++ b/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp @@ -68,7 +68,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // all loop iterations. for (unsigned long i = 0; i < adaptor.getVInitial().size(); i++) { Value origInput = loopOp.getVInitial()[i]; - if (origInput.getType().isa()) { + if (mlir::isa(origInput.getType())) { Value zero = create.math.constantIndex(0); create.krnl.store(adaptor.getVInitial()[i], outputs[i], zero); } else { @@ -79,9 +79,9 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // Convert the cond type to MemRefType. Type convertedType = typeConverter->convertType(adaptor.getCond().getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType condMemRefTy = convertedType.cast(); + MemRefType condMemRefTy = mlir::cast(convertedType); // Create a memref for recording loop condition, initialize it with the // initial loop condition. @@ -129,7 +129,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // For SeqType, load the value for the storage for (unsigned long i = 0; i < loopOp.getVInitial().size(); i++) { - if (loopOp.getVInitial()[i].getType().isa()) { + if (mlir::isa(loopOp.getVInitial()[i].getType())) { Value seqValue = create.krnl.load(outputs[i], zero); params.emplace_back(seqValue); } else { @@ -219,11 +219,10 @@ struct ONNXLoopOpLowering : public OpConversionPattern { outputs.begin() + adaptor.getVInitial().size(), outputs.end()); for (auto scanIntermediateToFinal : llvm::zip(scanIntermediate, scanOutputs)) { - Type elementType = std::get<1>(scanIntermediateToFinal) - .getType() - .cast() + Type elementType = mlir::cast( + std::get<1>(scanIntermediateToFinal).getType()) .getElementType(); - if (elementType.dyn_cast()) { + if (mlir::dyn_cast(elementType)) { // accumulate dynamic tensor rewriter.create(loc, std::get<0>(scanIntermediateToFinal), @@ -239,7 +238,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // outside the iteration scope so next iteration can use them as init // value. for (unsigned long i = 0; i < loopOp.getVInitial().size(); i++) { - if (loopOp.getVInitial()[i].getType().isa()) { + if (mlir::isa(loopOp.getVInitial()[i].getType())) { create.krnl.store(bodyOutputs[i + 1], outputs[i], zero); } else { emitCopy(rewriter, loc, bodyOutputs[i + 1], outputs[i]); @@ -267,8 +266,8 @@ struct ONNXLoopOpLowering : public OpConversionPattern { for (unsigned long i = 0; i < outputs.size(); i++) { Value output = outputs[i]; auto seqElementType = - output.getType().cast().getElementType(); - if (seqElementType.isa()) { + mlir::cast(output.getType()).getElementType(); + if (mlir::isa(seqElementType)) { // need to distinguish seqType in v_final and scan if (i < loopOp.v_final().size()) { // In v_final @@ -284,19 +283,21 @@ struct ONNXLoopOpLowering : public OpConversionPattern { create.krnl.load(output, create.math.constantIndex(0)); SmallVector allocParams; SmallVector dims; - dims.emplace_back(output.getType().cast().getShape()[0]); - if (output.getType().cast().isDynamicDim(0)) + dims.emplace_back( + mlir::cast(output.getType()).getShape()[0]); + if (mlir::cast(output.getType()).isDynamicDim(0)) allocParams.emplace_back(create.mem.dim(output, 0)); for (auto i = 0; - i < firstElement.getType().cast().getRank(); i++) { + i < mlir::cast(firstElement.getType()).getRank(); + i++) { dims.emplace_back( - firstElement.getType().cast().getShape()[i]); - if (firstElement.getType().cast().isDynamicDim(i)) + mlir::cast(firstElement.getType()).getShape()[i]); + if (mlir::cast(firstElement.getType()).isDynamicDim(i)) allocParams.emplace_back(create.mem.dim(firstElement, i)); } ArrayRef shape(dims.data(), dims.size()); auto flatType = MemRefType::get(shape, - firstElement.getType().cast().getElementType()); + mlir::cast(firstElement.getType()).getElementType()); Value alloc = create.mem.alignedAlloc(flatType, allocParams); // copy the value KrnlBuilder createKrnl(rewriter, loc); @@ -338,11 +339,11 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // Convert vFinal's type to MemRefType. Type convertedType = typeConverter->convertType(vFinal.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); - if (vFinal.getType().isa()) { + if (mlir::isa(vFinal.getType())) { memRefType = MemRefType::get({1}, memRefType); } @@ -363,9 +364,9 @@ struct ONNXLoopOpLowering : public OpConversionPattern { for (const auto &opScanOutput : loopOp.scan_outputs()) { // Convert opScanOutput's type to MemRefType. Type convertedType = typeConverter->convertType(opScanOutput.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Allocate memory for the scan outputs. There're no good "reference" // shape for scan outputs. So if the scan outputs do not have constant @@ -392,7 +393,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { if (isWhile) { llvm_unreachable("Scan output for while loop is not supported"); } - assert(!adaptor.getM().getType().isa()); + assert(!mlir::isa(adaptor.getM().getType())); Value maxTripCount = rewriter.create(loc, adaptor.getM()).getResult(); allocParams.emplace_back(rewriter.create( @@ -444,7 +445,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { const Value &src, const Value &dest, std::vector writePrefix = {}) const { OpBuilder::InsertionGuard insertGuard(rewriter); - auto srcTy = src.getType().cast(); + auto srcTy = mlir::cast(src.getType()); SmallVector readIV; MultiDialectBuilder create( rewriter, loc); @@ -505,8 +506,8 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // `ONNXYieldOp (cond, ..., ubValue, ..., newCounterValue, ...)` // which means the condition is loop invariant. Value breakCond = yieldOp->getOperands()[0]; - if (breakCond.isa() && - breakCond.cast().getArgNumber() == 1) { + if (mlir::isa(breakCond) && + mlir::cast(breakCond).getArgNumber() == 1) { } else return true; @@ -688,11 +689,11 @@ struct ONNXLoopOpLowering : public OpConversionPattern { resultsRange.begin(), resultsRange.end()); for (unsigned long i = 0; i < bodyOutputs.size(); i++) { Value output = bodyOutputs[i]; - assert((output.getType().isa() || - output.getType().isa()) && + assert((mlir::isa(output.getType()) || + mlir::isa(output.getType())) && "Expecting loop body function output to consist of " "tensors/memrefs."); - auto outputTy = output.getType().cast(); + auto outputTy = mlir::cast(output.getType()); bodyOutputs[i] = rewriter .create(loc, MemRefType::get(outputTy.getShape(), @@ -728,11 +729,10 @@ struct ONNXLoopOpLowering : public OpConversionPattern { for (auto scanIntermediateToFinal : llvm::zip(scanIntermediate, scanOutputs)) { - Type elementType = std::get<1>(scanIntermediateToFinal) - .getType() - .cast() + Type elementType = mlir::cast( + std::get<1>(scanIntermediateToFinal).getType()) .getElementType(); - if (elementType.dyn_cast()) { + if (mlir::dyn_cast(elementType)) { // TODO(chentong): handle dynamic scan output for while loop llvm_unreachable("Not implemented yet"); } else { diff --git a/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp b/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp index f4acaa4752..54b992c69c 100644 --- a/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp +++ b/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp @@ -145,11 +145,11 @@ struct ONNXScanOpLowering : public OpConversionPattern { resultsRange.begin(), resultsRange.end()); for (unsigned i = 0; i < bodyOutputs.size(); i++) { auto output = bodyOutputs[i]; - assert((output.getType().isa() || - output.getType().isa()) && + assert((mlir::isa(output.getType()) || + mlir::isa(output.getType())) && "Expecting scan body function output to consist of" "tensors/memrefs."); - auto outputTy = output.getType().cast(); + auto outputTy = mlir::cast(output.getType()); bodyOutputs[i] = rewriter .create(loc, MemRefType::get(outputTy.getShape(), @@ -210,9 +210,9 @@ struct ONNXScanOpLowering : public OpConversionPattern { // Convert vFinal's type to MemRefType. Type convertedType = typeConverter->convertType(vFinal.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Allocate memory for the loop-carried dependencies, since they are // guaranteed to have the same shape throughout all iterations, use @@ -231,9 +231,9 @@ struct ONNXScanOpLowering : public OpConversionPattern { for (const auto &opScanOutput : scanOp.scan_outputs()) { // Convert opScanOutput's type to MemRefType. Type convertedType = typeConverter->convertType(opScanOutput.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Allocate memory for the scan outputs. There're no good "reference" // shape for scan outputs. So if the scan outputs do not have constant @@ -279,9 +279,9 @@ struct ONNXScanOpLowering : public OpConversionPattern { mlir::Type bodyScanInputTy) { // Convert type to MemRefType. Type convertedType = typeConverter->convertType(bodyScanInputTy); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Allocate memory for the scan outputs. There're no good "reference" // shape for scan outputs. So if the scan outputs do not have constant @@ -311,7 +311,7 @@ struct ONNXScanOpLowering : public OpConversionPattern { std::vector writePrefix = {}) { OpBuilder::InsertionGuard insertGuard(builder); - auto srcTy = src.getType().cast(); + auto srcTy = mlir::cast(src.getType()); MultiDialectBuilder create( builder, loc); if (srcTy.getRank() > 0) { @@ -339,7 +339,7 @@ struct ONNXScanOpLowering : public OpConversionPattern { const Value &src, const Value &dest, std::vector readPrefix = {}) { OpBuilder::InsertionGuard insertGuard(builder); - auto srcTy = src.getType().cast(); + auto srcTy = mlir::cast(src.getType()); SmallVector readIV(readPrefix.begin(), readPrefix.end()); MultiDialectBuilder create( builder, loc); diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 3a450c4e65..5309a01092 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -85,7 +85,7 @@ class ONNXEntryPointLowering : public OpRewritePattern { .Case([&](ShapedType tensorTy) { auto et = tensorTy.getElementType(); dstream << " { \"type\" : "; - if (et.isa()) { + if (mlir::isa(et)) { // If use "et.print(dstream)", the output is !krnl.StringType. // The missing of quotation will fail the jason parser. // Use just "string" for brief @@ -106,7 +106,7 @@ class ONNXEntryPointLowering : public OpRewritePattern { } else { } dstream << "] "; - auto name = attr.cast().getValue().str(); + auto name = mlir::cast(attr).getValue().str(); dstream << ", \"name\" : \"" << name << "\""; }) .Default([&](Type type) { llvm_unreachable("input is not a tensor"); }); @@ -133,10 +133,8 @@ class ONNXEntryPointLowering : public OpRewritePattern { if (argAttrs) { DictionaryAttr dictAttrs = llvm::dyn_cast(argAttrs[i]); if (dictAttrs && dictAttrs.contains("onnx.name")) - inputName = dictAttrs.getNamed("onnx.name") - .value() - .getValue() - .cast(); + inputName = mlir::cast( + dictAttrs.getNamed("onnx.name").value().getValue()); } concatTypeString(inputs[i], inputName, dstream); comma = std::string(" , "); @@ -152,10 +150,8 @@ class ONNXEntryPointLowering : public OpRewritePattern { if (argAttrs) { DictionaryAttr dictAttrs = llvm::dyn_cast(resAttrs[i]); if (dictAttrs && dictAttrs.contains("onnx.name")) - outputName = dictAttrs.getNamed("onnx.name") - .value() - .getValue() - .cast(); + outputName = mlir::cast( + dictAttrs.getNamed("onnx.name").value().getValue()); } concatTypeString(outputs[i], outputName, dstream); comma = std::string(" , "); @@ -440,7 +436,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() { // Operations that are legal only if types are not tensors. target.addDynamicallyLegalOp([&](Operation *op) { return llvm::none_of(op->getOperandTypes(), - [](Type type) { return type.isa(); }); + [](Type type) { return mlir::isa(type); }); }); // Define patterns. diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp index f1076a24fa..5b6e561a78 100644 --- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp +++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp @@ -81,13 +81,13 @@ struct ONNXCategoryMapperOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Basic information. int64_t rank = memRefType.getShape().size(); - ShapedType inputType = X.getType().cast(); + ShapedType inputType = mlir::cast(X.getType()); Type elementType = inputType.getElementType(); // Insert an allocation and deallocation for the result of this operation. @@ -102,15 +102,17 @@ struct ONNXCategoryMapperOpLowering // Convert the cats type to MemRefType. Type convertedCatsInt64s = typeConverter->convertType(cats_int64s.getType()); - assert(convertedCatsInt64s && convertedCatsInt64s.isa() && + assert(convertedCatsInt64s && mlir::isa(convertedCatsInt64s) && "Failed to convert type to MemRefType"); - MemRefType catsInt64sInMemRefType = convertedCatsInt64s.cast(); + MemRefType catsInt64sInMemRefType = + mlir::cast(convertedCatsInt64s); Type convertedCatsStrings = typeConverter->convertType(cats_strings.getType()); - assert(convertedCatsStrings && convertedCatsStrings.isa() && + assert(convertedCatsStrings && + mlir::isa(convertedCatsStrings) && "Failed to convert type to MemRefType"); MemRefType catsStringsInMemRefType = - convertedCatsStrings.cast(); + mlir::cast(convertedCatsStrings); // Create loop invariant values. Value constantForCatsInt64s = create.krnl.constant( @@ -218,7 +220,7 @@ struct ONNXCategoryMapperOpLowering int32_t size = cats_int64s.size(); for (int32_t idx = 0; idx < size; ++idx) { Attribute elemAttr = getElemAttr(cats_int64s_ArrayAttr, idx); - int64_t key = elemAttr.cast().getInt(); + int64_t key = mlir::cast(elemAttr).getInt(); dict[key] = idx; } @@ -233,7 +235,7 @@ struct ONNXCategoryMapperOpLowering int32_t size = cats_strings.size(); for (int32_t idx = 0; idx < size; ++idx) { Attribute elemAttr = getElemAttr(cats_strings_ArrayAttr, idx); - StringRef key = elemAttr.cast().getValue(); + StringRef key = mlir::cast(elemAttr).getValue(); dict[key] = idx; } @@ -258,7 +260,7 @@ struct ONNXCategoryMapperOpLowering [&](IntegerType) { inputElem = createKrnl.load(memref, loopInd); }) .Case([&](krnl::StringType stringType) { ArrayRef shape = - memref.getType().cast().getShape(); + mlir::cast(memref.getType()).getShape(); SmallVector newShape; bool hasDynamicDim = false; for (uint64_t i = 0; i < shape.size(); i++) { diff --git a/src/Conversion/ONNXToKrnl/Math/CumSum.cpp b/src/Conversion/ONNXToKrnl/Math/CumSum.cpp index 19252cbff5..195bdf0285 100644 --- a/src/Conversion/ONNXToKrnl/Math/CumSum.cpp +++ b/src/Conversion/ONNXToKrnl/Math/CumSum.cpp @@ -99,9 +99,9 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Common information. Type elementType = memRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 31a565ce05..0574f1c7c8 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -43,12 +43,12 @@ Value emitPostProcessingFor(ConversionPatternRewriter &rewriter, Location loc, template static void CheckIfCustomScalarOpIsSupported(Type elementType) { Type actualElementType = MathBuilder::elementTypeWithVector(elementType); - if (actualElementType.isa()) { + if (mlir::isa(actualElementType)) { if constexpr (std::is_same, CustomScalarOp>::value) return; llvm_unreachable("this op does not support custom scalar for integers"); } - if (actualElementType.isa()) { + if (mlir::isa(actualElementType)) { if constexpr (std::is_same, CustomScalarOp>::value) return; llvm_unreachable("this op does not support custom scalar for floats"); @@ -1137,7 +1137,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // If the two input values are a string then we want to use the krnlStrnCmp. // However, if the input values are a float or an int we can simply use the // equal function. - if (inputElemType.isa()) { + if (mlir::isa(inputElemType)) { Value strlenRes = create.krnl.strlen(lhs); Value strncmpRes = create.krnl.strncmp(lhs, rhs, strlenRes); // Confirm the strncmp is indeed valid. strncmp returns a value of 0 if the @@ -1573,7 +1573,8 @@ static LogicalResult getPartiallyFlattenedSimdCode( loadedVals.emplace_back(flatOper); continue; } - MemRefType memRefType = flatOper.getType().dyn_cast(); + MemRefType memRefType = + mlir::dyn_cast(flatOper.getType()); assert(memRefType && "expected memref"); VectorType vecType = VectorType::get({VL}, memRefType.getElementType()); @@ -1583,7 +1584,7 @@ static LogicalResult getPartiallyFlattenedSimdCode( if (hasOneElement(flatOper)) { // Not flattened, with only 1 dims, just put zeros as needed. int64_t scalarRank = - flatOper.getType().dyn_cast().getRank(); + mlir::dyn_cast(flatOper.getType()).getRank(); for (int r = 0; r < scalarRank; ++r) scalarAccessFct.emplace_back(LiteralIndexExpr(0)); @@ -2048,7 +2049,8 @@ struct ONNXElementwiseUnaryOpLowering // If type is scalar or vector, there is no need to allocate a buffer. // Just call scalar computation and return the result. This is efficient // when elementwise ops are used as activations for ops like LSTM/GRU/RNN. - if (!X.getType().isa() && !X.getType().isa()) { + if (!mlir::isa(X.getType()) && + !mlir::isa(X.getType())) { SmallVector args; args.emplace_back(X); // Load the remaining (scalar) values. @@ -2057,8 +2059,8 @@ struct ONNXElementwiseUnaryOpLowering args.emplace_back(operands[i]); continue; } - assert(!operands[i].getType().isa() && - !operands[i].getType().isa() && + assert(!mlir::isa(operands[i].getType()) && + !mlir::isa(operands[i].getType()) && "unary expected scalar additional values"); args.emplace_back(operands[i]); } @@ -2073,9 +2075,9 @@ struct ONNXElementwiseUnaryOpLowering Type convertedType = this->typeConverter->convertType(outputTensorType); int64_t alignment = KrnlTypeConverter::getDefaultAllocAlignment(outputTensorType); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); int64_t outputRank = outputMemRefType.getRank(); Type elementType = outputMemRefType.getElementType(); @@ -2227,9 +2229,9 @@ struct ONNXElementwiseBinaryOpLowering Type convertedType = this->typeConverter->convertType(outputTensorType); int64_t alignment = KrnlTypeConverter::getDefaultAllocAlignment(outputTensorType); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); Type outputElementType = outputMemRefType.getElementType(); uint64_t outputRank = outputMemRefType.getRank(); @@ -2404,9 +2406,9 @@ struct ONNXElementwiseVariadicOpLowering Type convertedType = this->typeConverter->convertType(outputTensorType); int64_t alignment = KrnlTypeConverter::getDefaultAllocAlignment(outputTensorType); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); Type outputElementType = outputMemRefType.getElementType(); uint64_t outputRank = outputMemRefType.getRank(); @@ -2583,9 +2585,9 @@ struct ONNXWhereOpLowering : public ConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); uint64_t outputRank = outputMemRefType.getRank(); ONNXWhereOpAdaptor operandAdaptor(operands); diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index 494bd58fea..edf1621f8f 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -385,9 +385,9 @@ struct ONNXGemmOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = this->typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the output of this operation. Type elementType = outputMemRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp b/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp index 36e49a5b12..cb1db12d17 100644 --- a/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp @@ -28,7 +28,7 @@ static Value emitArgmax(ConversionPatternRewriter &rewriter, Location loc, create(rewriter, loc); IndexExprScope scope(create.krnl); - MemRefType memRefType = input.getType().cast(); + MemRefType memRefType = mlir::cast(input.getType()); Type indexType = rewriter.getIndexType(); int64_t rank = memRefType.getRank(); Value zero = create.math.constantIndex(0); @@ -94,9 +94,9 @@ struct ONNXHardmaxOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Type elementType = memRefType.getElementType(); Value zero = create.math.constantIndex(0); diff --git a/src/Conversion/ONNXToKrnl/Math/LRN.cpp b/src/Conversion/ONNXToKrnl/Math/LRN.cpp index ca609c63b3..8095e294e3 100644 --- a/src/Conversion/ONNXToKrnl/Math/LRN.cpp +++ b/src/Conversion/ONNXToKrnl/Math/LRN.cpp @@ -39,9 +39,9 @@ struct ONNXLRNOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); auto outputMemRefShape = outputMemRefType.getShape(); Type elementType = outputMemRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index 612f54004a..69910f70d7 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -465,9 +465,9 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the output of this operation. Type elementType = outputMemRefType.getElementType(); @@ -478,9 +478,9 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { Value zero = create.math.constant(elementType, 0); Value A(adaptor.getA()), B(adaptor.getB()); - int aRank = A.getType().cast().getShape().size(); - int bRank = B.getType().cast().getShape().size(); - int cRank = alloc.getType().cast().getShape().size(); + int aRank = mlir::cast(A.getType()).getShape().size(); + int bRank = mlir::cast(B.getType()).getShape().size(); + int cRank = mlir::cast(alloc.getType()).getShape().size(); if (enableTiling && aRank == 2 && bRank == 2) { // Optimized Matmul only when 2D and allowed to tile and unroll. assert(cRank == 2 && "expected IxK * KxJ = IxJ 2D result"); diff --git a/src/Conversion/ONNXToKrnl/Math/MatMulInteger.cpp b/src/Conversion/ONNXToKrnl/Math/MatMulInteger.cpp index afbe1295c9..948b9632dd 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMulInteger.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMulInteger.cpp @@ -58,7 +58,7 @@ struct ONNXMatMulIntegerOpLowering // Prepare input A. Value AInt32 = create.onnx.cast(A, resElementType); if (!isNoneValue(aZeroPoint)) { - auto aZeroPointType = aZeroPoint.getType().cast(); + auto aZeroPointType = mlir::cast(aZeroPoint.getType()); int64_t aZeroPointRank = aZeroPointType.getRank(); Value aZeroPointInt32 = create.onnx.cast(aZeroPoint, resElementType); // If broadcasting, e.g. A is [MxK], zeroPoint is [M], M != 1. diff --git a/src/Conversion/ONNXToKrnl/Math/RandomNormal.cpp b/src/Conversion/ONNXToKrnl/Math/RandomNormal.cpp index af1edf0513..d9f087d895 100644 --- a/src/Conversion/ONNXToKrnl/Math/RandomNormal.cpp +++ b/src/Conversion/ONNXToKrnl/Math/RandomNormal.cpp @@ -36,9 +36,9 @@ struct ONNXRandomNormalOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); ArrayRef outputMemRefShape = outputMemRefType.getShape(); size_t outputRank = outputMemRefShape.size(); diff --git a/src/Conversion/ONNXToKrnl/Math/RandomNormalLike.cpp b/src/Conversion/ONNXToKrnl/Math/RandomNormalLike.cpp index cf32425823..eadd773151 100644 --- a/src/Conversion/ONNXToKrnl/Math/RandomNormalLike.cpp +++ b/src/Conversion/ONNXToKrnl/Math/RandomNormalLike.cpp @@ -37,9 +37,9 @@ struct ONNXRandomNormalLikeOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); ArrayRef outputMemRefShape = outputMemRefType.getShape(); int outputRank = outputMemRefShape.size(); Type elementType = outputMemRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index 7faf3cd63f..edca352250 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -376,12 +376,12 @@ struct ONNXReductionOpLowering : public OpConversionPattern { ////////////////////////////////////////////////////////////////////// // Handle type conversion. - MemRefType memRefInType = input.getType().cast(); + MemRefType memRefInType = mlir::cast(input.getType()); Type convertedOutType = this->typeConverter->convertType(*op->result_type_begin()); - assert(convertedOutType && convertedOutType.isa() && + assert(convertedOutType && mlir::isa(convertedOutType) && "Failed to convert type to MemRefType"); - MemRefType memRefOutType = convertedOutType.cast(); + MemRefType memRefOutType = mlir::cast(convertedOutType); int64_t inRank = memRefInType.getRank(); int64_t outRank = memRefOutType.getRank(); auto memRefOutShape = memRefOutType.getShape(); @@ -571,9 +571,10 @@ struct ONNXReductionOpLowering : public OpConversionPattern { RankedTensorType::get({inRank}, rewriter.getIntegerType(1)); // Convert the mask type to MemRefType. Type convertedMaskType = this->typeConverter->convertType(maskType); - assert(convertedMaskType && convertedMaskType.isa() && + assert(convertedMaskType && mlir::isa(convertedMaskType) && "Failed to convert type to MemRefType"); - MemRefType maskTypeInMemRefType = convertedMaskType.cast(); + MemRefType maskTypeInMemRefType = + mlir::cast(convertedMaskType); maskVal = create.mem.alignedAlloc(maskTypeInMemRefType); falseVal = create.math.constant(rewriter.getIntegerType(1), 0); trueVal = create.math.constant(rewriter.getIntegerType(1), 1); @@ -600,7 +601,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // Consider the case when axes[i] is negative // maskVal[axes[i] < 0 ? axes[i]+inRank: axes[i]] = 1 auto axesElementType = - axesVal.getType().cast().getElementType(); + mlir::cast(axesVal.getType()).getElementType(); auto dataDimConst = create.math.constant(axesElementType, inRank); Value zeroValue = create.math.constant(axesElementType, 0); if (!axisShape0.isLiteral()) { diff --git a/src/Conversion/ONNXToKrnl/Math/Softmax.cpp b/src/Conversion/ONNXToKrnl/Math/Softmax.cpp index 31c695745d..85fa78ad0a 100644 --- a/src/Conversion/ONNXToKrnl/Math/Softmax.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Softmax.cpp @@ -25,7 +25,7 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops, SmallVectorImpl &Lbs, SmallVectorImpl &Ubs, ValueRange outerIndices, Value input, Value alloc, Value zero, Value negInfinity, int64_t axis, bool coerced = true) { - int64_t rank = alloc.getType().cast().getRank(); + int64_t rank = mlir::cast(alloc.getType()).getRank(); ValueRange maxInits = ValueRange(negInfinity); // Compute the maximum value along axis. @@ -142,7 +142,7 @@ template <> void emitInstForSoftmax(ConversionPatternRewriter &rewriter, Operation *op, Location loc, Value alloc, Value input, Value zero, Value negInfinity, int64_t axis, bool enableParallel) { - int64_t rank = alloc.getType().cast().getRank(); + int64_t rank = mlir::cast(alloc.getType()).getRank(); MultiDialectBuilder create( rewriter, loc); @@ -208,7 +208,7 @@ template <> void emitInstForSoftmax(ConversionPatternRewriter &rewriter, Operation *op, Location loc, Value alloc, Value input, Value zero, Value negInfinity, int64_t axis, bool enableParallel) { - int64_t rank = alloc.getType().cast().getRank(); + int64_t rank = mlir::cast(alloc.getType()).getRank(); MultiDialectBuilder create( rewriter, loc); @@ -279,9 +279,9 @@ struct ONNXSoftmaxLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = this->typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); int64_t rank = memRefType.getRank(); int64_t axis = adaptor.getAxis(); diff --git a/src/Conversion/ONNXToKrnl/Math/TopK.cpp b/src/Conversion/ONNXToKrnl/Math/TopK.cpp index 765e55a5cf..1b937a9c43 100644 --- a/src/Conversion/ONNXToKrnl/Math/TopK.cpp +++ b/src/Conversion/ONNXToKrnl/Math/TopK.cpp @@ -35,9 +35,9 @@ struct ONNXTopKOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType resMemRefType = convertedType.cast(); + MemRefType resMemRefType = mlir::cast(convertedType); // Common types. Type i64Type = rewriter.getI64Type(); diff --git a/src/Conversion/ONNXToKrnl/Math/Trilu.cpp b/src/Conversion/ONNXToKrnl/Math/Trilu.cpp index 109cc19283..e60261f517 100644 --- a/src/Conversion/ONNXToKrnl/Math/Trilu.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Trilu.cpp @@ -41,9 +41,9 @@ struct ONNXTriluOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); int64_t rank = memRefType.getRank(); Type elementType = memRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp index f1cab98d57..d79935ed0c 100644 --- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp @@ -45,7 +45,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { auto inputOperand = operandAdaptor.getX(); auto filterOperand = operandAdaptor.getW(); auto biasOperand = operandAdaptor.getB(); - bool hasBias = !biasOperand.getType().isa(); + bool hasBias = !mlir::isa(biasOperand.getType()); int64_t groupNum = convOp.getGroup(); IndexExpr G = LiteralIndexExpr(groupNum); Value fZero = create.math.constant(memRefType.getElementType(), 0); @@ -245,7 +245,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { // Insert allocation for the result of this operation. Value alloc = allocForONNXOp( convOp, rewriter, typeConverter, shapeHelper)[0]; - MemRefType memRefType = alloc.getType().cast(); + MemRefType memRefType = mlir::cast(alloc.getType()); convUnoptimized(rewriter, convOp, adaptor, shapeHelper, memRefType, alloc); rewriter.replaceOp(op, alloc); diff --git a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp index 9f3990776b..297a649a9f 100644 --- a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp @@ -48,9 +48,9 @@ struct ONNXBatchNormalizationInferenceModeOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Value epsilon = create.math.constant( memRefType.getElementType(), adaptor.getEpsilon().convertToDouble()); @@ -170,9 +170,9 @@ struct ONNXInstanceNormalizationOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Type elementType = memRefType.getElementType(); Value epsilon = create.math.constant( elementType, adaptor.getEpsilon().convertToDouble()); @@ -343,7 +343,7 @@ LogicalResult generateGenericLayerNormOpONNXCode( const TypeConverter *const typeConverter) { MDBuilder create(rewriter, loc); Value X = lnOp.getX(); // Original value, not translated. - TensorType XType = X.getType().cast(); + TensorType XType = mlir::cast(X.getType()); Type elementType = XType.getElementType(); int64_t XRank = XType.getRank(); int64_t axis = getAxisInRange(lnOp.getAxis(), XRank); @@ -464,7 +464,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { Value operand, int64_t operandIndex, int64_t axis, int64_t XRank, IndexExpr &modFactor) const { DimsExpr &operandDims = shapeHelper.inputsDims[operandIndex]; - int64_t operandRank = operand.getType().cast().getRank(); + int64_t operandRank = mlir::cast(operand.getType()).getRank(); modFactor = LiteralIndexExpr(1); // X: X0 X1 X2 | X3 X4 X5 . @@ -596,7 +596,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { // Get info. Value X = adaptor.getX(); - MemRefType XMemRefType = X.getType().cast(); + MemRefType XMemRefType = mlir::cast(X.getType()); DimsExpr XDims = shapeHelper.inputsDims[0]; int64_t XRank = XMemRefType.getRank(); int64_t axis = getAxisInRange(lnOp.getAxis(), XRank); @@ -701,9 +701,9 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { /*output*/ Value &flatMemRef) const { // Convert input. Type convertedType = this->typeConverter->convertType(inputVal.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Allocate. memRef = create.mem.alignedAlloc(memRefType, inputDims); // Flatten (do not keep flatten dims at this time). @@ -768,7 +768,8 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { if constexpr (std::is_same::value) isTraditionalLayerNorm = true; // Vector type. - Type elementType = YMemRef.getType().cast().getElementType(); + Type elementType = + mlir::cast(YMemRef.getType()).getElementType(); VectorType vecType = VectorType::get({VL}, elementType); // Init the two reductions. Value init = create.math.constant(elementType, 0.0); @@ -889,7 +890,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { MDBuilder create(rewriter, loc); Operation *op = lnOp.getOperation(); Value XMemRef = adaptor.getX(); - MemRefType XMemRefType = XMemRef.getType().cast(); + MemRefType XMemRefType = mlir::cast(XMemRef.getType()); Type elementType = XMemRefType.getElementType(); int64_t XRank = XMemRefType.getRank(); int64_t axis = getAxisInRange(lnOp.getAxis(), XRank); @@ -906,7 +907,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { // Fully flatten scale input. int64_t scaleRank = - adaptor.getScale().getType().template cast().getRank(); + mlir::cast(adaptor.getScale().getType()).getRank(); DimsExpr scaleDims; for (int64_t i = XRank - scaleRank; i < XRank; ++i) scaleDims.emplace_back(shapeHelper.inputsDims[1][i]); @@ -919,7 +920,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { // Fully flatten bias input, if present. if (!isNoneValue(lnOp.getB())) { int64_t biasRank = - adaptor.getB().getType().template cast().getRank(); + mlir::cast(adaptor.getB().getType()).getRank(); DimsExpr biasDims; for (int64_t i = XRank - biasRank; i < XRank; ++i) biasDims.emplace_back(shapeHelper.inputsDims[2][i]); @@ -1048,28 +1049,28 @@ def layer_norm_simd2_v3(x, a, scale, b): y = np.zeros((a1, s2)) for i1 in range(0, a1, b1): # iterate over blocks of b1 values - + # MEAN(x), MEAN(x2) # iterate over a_block, s_block: parallel add r = np.zeros((b1, b2)) r_square = np.zeros((b1, b2)) - for i2 in range(0, s2, b2): # Unroll B1, SIMD by B2, + for i2 in range(0, s2, b2): # Unroll B1, SIMD by B2, xx = x[i1:i1+b1, i2:i2+b2] xxx = np.multiply(xx, xx) r = np.add(r, xx) r_square = np.add(r_square, xxx) - + # simd reduction; res B1 x 1 # 2 B1 div mean_b = np.sum(r, axis=1, keepdims=True) # SIMD reduction to (B1x1) values. - mean_b = np.divide(mean_b, s2) # (B2x1) values... so scalar is ok. + mean_b = np.divide(mean_b, s2) # (B2x1) values... so scalar is ok. mean_square_b = np.sum(r_square, axis=1, keepdims=True) # Same. - mean_square_b = np.divide(mean_square_b, s2) + mean_square_b = np.divide(mean_square_b, s2) # var = mean_square - mean**2; all compute here are on (B1x1): B1 mul, B1 add mean2_b = np.multiply(mean_b, mean_b) # B1 values, ok to do scalar var_b = np.subtract(mean_square_b, mean2_b) - + # ADD eps, sqrt, inverse # computations on B1x1, scalar ok: B1 add, B1 sqrt, B1 div var_eps_b = np.add(var_b, 1e-05) @@ -1079,7 +1080,7 @@ def layer_norm_simd2_v3(x, a, scale, b): # tot ops up to here (on B1x1): div: 3*B1, sqrt: B1, mul B1, add/sub 2 B1, sqrt B1: tot 8 B1 # SIMD on entire S2 size - for i2 in range(0, s2, b2): # Unroll B1, SIMD by B2, + for i2 in range(0, s2, b2): # Unroll B1, SIMD by B2, x_b = x[i1:i1+b1, i2:i2+b2] d_b = np.subtract(x_b, mean_b) # broadcast of mean_b of 1 -> s2 normalized_b = np.multiply(d_b, inv_std_dev_b) # broadcast of mean_b of 1 -> s2 diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index f2e8449298..c0ffc46a4a 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -61,7 +61,7 @@ std::vector getDilations(PoolOp poolOp) { ArrayAttr dilationsAttribute = poolOp.getDilationsAttr(); bool isDefaultDilations = true; for (auto dilation : dilationsAttribute.getValue()) { - int64_t dilationValue = dilation.cast().getInt(); + int64_t dilationValue = mlir::cast(dilation).getInt(); if (dilationValue > 1 && isDefaultDilations) isDefaultDilations = false; dilations.emplace_back(dilationValue); @@ -205,14 +205,14 @@ struct ONNXPoolOpLowering : public OpConversionPattern { // Type information about the input and result of this operation. Value inputOperand = adaptor.getX(); - auto inputShape = inputOperand.getType().cast().getShape(); + auto inputShape = mlir::cast(inputOperand.getType()).getShape(); // Convert the output type to MemRefType. Type convertedType = this->typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); ArrayRef outputShape = memRefType.getShape(); Type outputElementType = memRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index bd7918589b..99e9b0ace5 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -33,7 +33,7 @@ Value OnnxToKrnlBuilder::reshape( const Value input, const ArrayRef shapeDims) const { assert(!shapeDims.empty() && "Shape dimensions should not be empty"); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); MultiDialectBuilder create(b(), loc()); @@ -101,7 +101,7 @@ Value OnnxToKrnlBuilder::transpose(const Value input, shape.push_back(dim.isLiteral() ? dim.getLiteral() : ShapedType::kDynamic); // Create the "onnx.Transpose" operation. - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Value transposeRes = create.onnx.transpose(MemRefType::get(shape, inputType.getElementType()), input, b().getI64ArrayAttr(perm)); @@ -110,7 +110,7 @@ Value OnnxToKrnlBuilder::transpose(const Value input, } bool isScalarValue(Value value) { - ShapedType stype = value.getType().dyn_cast(); + ShapedType stype = mlir::dyn_cast(value.getType()); assert(stype && "expected shaped type"); return stype.getRank() == 0; } @@ -131,7 +131,7 @@ bool hasAllScalarValues(ValueRange values) { bool hasOneElement(Value value) { if (isScalarValue(value)) return true; - ShapedType type = value.getType().dyn_cast(); + ShapedType type = mlir::dyn_cast(value.getType()); assert(type && "expected shaped type"); for (int64_t s : type.getShape()) if (s != 1) @@ -143,7 +143,7 @@ bool hasOneElement(Value value) { bool hasOneElementInInnermostDims(Value value, int64_t innerDim) { if (isScalarValue(value)) return true; - ShapedType type = value.getType().dyn_cast(); + ShapedType type = mlir::dyn_cast(value.getType()); assert(type && "expected shaped type"); mlir::ArrayRef shape = type.getShape(); int64_t rank = type.getRank(); @@ -158,7 +158,8 @@ bool hasOneElementInInnermostDims(Value value, int64_t innerDim) { bool indicesAreNonNegativeConstants(Value indices) { DenseElementsAttr valueAttribute = krnl::getDenseElementAttributeFromKrnlValue(indices); - if (!valueAttribute || !valueAttribute.getElementType().isa()) + if (!valueAttribute || + !mlir::isa(valueAttribute.getElementType())) return false; return llvm::all_of(valueAttribute.getValues(), @@ -208,7 +209,7 @@ std::map getReductionMapping( // Dynamic dimension are supported. void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc, krnl::KrnlIterateOperandPack &pack, Value operand, int index) { - auto shape = operand.getType().cast().getShape(); + auto shape = mlir::cast(operand.getType()).getShape(); assert(shape[index] != -1 && "expected kDynamic, not -1"); if (shape[index] == ShapedType::kDynamic) { MultiDialectBuilder create(rewriter, loc); @@ -233,7 +234,8 @@ void defineLoops(ConversionPatternRewriter &rewriter, Location loc, Value getDimOrConstant(ConversionPatternRewriter &rewriter, Location loc, Value operand, int64_t axis, Type type) { MultiDialectBuilder create(rewriter, loc); - ArrayRef shape = operand.getType().cast().getShape(); + ArrayRef shape = + mlir::cast(operand.getType()).getShape(); assert(shape[axis] != -1 && "expected kDynamic, not -1"); return (shape[axis] == ShapedType::kDynamic) ? create.math.cast(type, create.mem.dim(operand, axis)) @@ -275,10 +277,10 @@ DenseElementsAttr getDenseElementAttrFromConstValue(mlir::Value value) { } if (auto globalOp = dyn_cast_or_null(definingOp)) { if (globalOp.getValue().has_value()) - return globalOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(globalOp.getValueAttr()); } else if (auto constOp = dyn_cast_or_null(definingOp)) { if (constOp.getValue().has_value()) - return constOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constOp.getValueAttr()); } return nullptr; } @@ -351,7 +353,7 @@ Value emitArgSort(ConversionPatternRewriter &rewriter, Location loc, create(rewriter, loc); IndexExprScope scope(create.krnl); - MemRefType inputMemRefType = input.getType().cast(); + MemRefType inputMemRefType = mlir::cast(input.getType()); Type indexType = rewriter.getIndexType(); int64_t rank = inputMemRefType.getRank(); assert(axis >= 0 && axis < rank && "axis is out of bound"); @@ -430,9 +432,9 @@ Value emitArgSort(ConversionPatternRewriter &rewriter, Location loc, Value getOptionalScalarValue(ConversionPatternRewriter &rewriter, Location loc, Value optionalScalar, Type elementType, double defaultValue) { MultiDialectBuilder create(rewriter, loc); - if (optionalScalar.getType().isa()) { + if (mlir::isa(optionalScalar.getType())) { return create.math.constant(elementType, defaultValue); - } else if (optionalScalar.getType().cast().getRank() == 0) { + } else if (mlir::cast(optionalScalar.getType()).getRank() == 0) { return create.krnl.load(optionalScalar, {}); } else { Value zero = create.math.constantIndex(0); @@ -446,7 +448,7 @@ Value getOptionalScalarValue(ConversionPatternRewriter &rewriter, Location loc, MemRefType convertTypeWithCustomONNXDataLayoutToMemRef(Type type) { // Get tensor rank, shape, and element type. - RankedTensorType tensorType = type.dyn_cast(); + RankedTensorType tensorType = mlir::dyn_cast(type); assert(tensorType && "expected only ranked shapes"); ArrayRef shape = tensorType.getShape(); int64_t rank = shape.size(); @@ -526,7 +528,7 @@ KrnlTypeConverter::KrnlTypeConverter() { addConversion([](TensorType tensorType) { assert(tensorType.hasRank() && "expected only ranked shapes"); - if (tensorType.getElementType().isa()) { + if (mlir::isa(tensorType.getElementType())) { Type elementType = krnl::StringType::get(tensorType.getContext()); return MemRefType::get(tensorType.getShape(), elementType); } @@ -543,7 +545,7 @@ KrnlTypeConverter::KrnlTypeConverter() { }); addConversion([](SeqType seqType) { - auto seqElementType = seqType.getElementType().cast(); + auto seqElementType = mlir::cast(seqType.getElementType()); Type elementType = seqElementType.getElementType(); Type seqElementConvertedType; if (seqElementType.hasRank()) { @@ -581,7 +583,7 @@ KrnlTypeConverter::KrnlTypeConverter() { int64_t KrnlTypeConverter::getDefaultAllocAlignment(Type type) { int64_t alignment = -1; - if (auto tensorType = type.dyn_cast()) { + if (auto tensorType = mlir::dyn_cast(type)) { // Accelerators may have special versions of TensorType. Call the // conversions of accelerators. for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) { @@ -601,7 +603,7 @@ bool hasNonIdentityLayout(Value val) { if (isNoneValue(val)) return false; // Expect a memref now. - MemRefType type = val.getType().dyn_cast(); + MemRefType type = mlir::dyn_cast(val.getType()); assert(type && "expected a memref type"); return hasNonIdentityLayout(type); } diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 82e10ba141..2ee3e11bf7 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -229,7 +229,7 @@ mlir::Value emitScalarOpFor(mlir::ConversionPatternRewriter &rewriter, mlir::Type actualElementType = MathBuilder::elementTypeWithVector(scalarOperands[0].getType()); // Perform int or float operation depending on the actual elementary type. - if (actualElementType.isa()) { + if (mlir::isa(actualElementType)) { // Generate the integer code only if the scalar integer op is non-void // (unsupported) and non-int (supported by custom sequence of ops). if constexpr (!(std::is_same, NotSuportedScalarOp>::value) && @@ -237,7 +237,7 @@ mlir::Value emitScalarOpFor(mlir::ConversionPatternRewriter &rewriter, return rewriter.create>( loc, elementType, scalarOperands, std::nullopt); llvm_unreachable("unsupported integer operation"); - } else if (actualElementType.isa()) { + } else if (mlir::isa(actualElementType)) { // Generate the floating point code only if the scalar integer op is // non-void (unsupported) and non-int (supported by custom sequence of ops). if constexpr (!(std::is_same, NotSuportedScalarOp>::value) && @@ -493,9 +493,9 @@ std::vector allocForONNXOp(mlir::Operation *op, mlir::Value output = op->getResults()[i]; // Convert the output type to MemRefType. mlir::Type convertedType = typeConverter->convertType(output.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - mlir::MemRefType memRefType = convertedType.cast(); + mlir::MemRefType memRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the result of this operation. mlir::Value alloc = diff --git a/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp b/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp index ebbc7aeb82..ab216bea57 100644 --- a/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp +++ b/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp @@ -170,8 +170,8 @@ static Value tryToUnflip( IndexExpr ss = ubs[1]; // spatial size. LiteralIndexExpr zeroIE(0), oneIE(1), twoIE(2), threeIE(3); - Value resMemRef = - create.mem.alignedAlloc(boundingBoxes.getType().cast(), ubs); + Value resMemRef = create.mem.alignedAlloc( + mlir::cast(boundingBoxes.getType()), ubs); ValueRange loopDef = create.krnl.defineLoops(2); create.krnl.iterateIE(loopDef, loopDef, {zeroIE, zeroIE}, {bs, ss}, @@ -225,9 +225,9 @@ struct ONNXNonMaxSuppressionOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Common information. Type elementType = memRefType.getElementType(); @@ -244,7 +244,7 @@ struct ONNXNonMaxSuppressionOpLowering Value maxOutputBoxPerClass = getOptionalScalarValue( rewriter, loc, adaptor.getMaxOutputBoxesPerClass(), i64Type, 0); // Score threshold. - Type scoreType = scores.getType().cast().getElementType(); + Type scoreType = mlir::cast(scores.getType()).getElementType(); Value scoreTH = getOptionalScalarValue( rewriter, loc, adaptor.getScoreThreshold(), scoreType, 0); // IOU threshold. @@ -477,50 +477,50 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // # corners. // y1_min, x1_min, y1_max, x1_max = box1 // y2_min, x2_min, y2_max, x2_max = box2 -// +// // area1 = (y1_max - y1_min) * (x1_max - x1_min) // area2 = (y2_max - y2_min) * (x2_max - x2_min) // else: // # The box data is supplied as [x_center, y_center, width, height]. // x1_center, y1_center, w1, h1 = box1 // x2_center, y2_center, w2, h2 = box2 -// +// // x1_min = x1_center - w1 / 2 // x1_max = x1_center + w1 / 2 // x2_min = x2_center - w2 / 2 // x2_max = x2_center + w2 / 2 -// +// // y1_min = y1_center - h1 / 2 // y1_max = y1_center + h1 / 2 // y2_min = y2_center - h2 / 2 // y2_max = y2_center + h2 / 2 -// +// // area1 = h1 * w1 // area2 = h2 * w2 -// +// // intersection_x_min = max(x1_min, x2_min) // intersection_y_min = max(y1_min, y2_min) // intersection_x_max = min(x1_max, x2_max) // intersection_y_max = min(y1_max, y2_max) // intersection_area = max(intersection_x_max - intersection_x_min, 0) * \ // max(intersection_y_max - intersection_y_min, 0) -// +// // union_area = area1 + area2 - intersection_area + 1e-8 // return intersection_area / union_area -// -// +// +// // ''' // boxes :: [num_batch, spatial_dimension, 4] // scores :: [num_batch, num_class, spatial_dimension] // ''' -// -// +// +// // def nms(boxes, scores, max_output_boxes_per_class, iou_threshold, // score_threshold, center_point_box=0): // batch_size = scores.shape[0] // class_size = scores.shape[1] // spatial_size = boxes.shape[1] -// +// // score_threshold = score_threshold[0] // iou_threshold = iou_threshold[0] // # Suppress by spatial dimension. @@ -536,7 +536,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // max_per_class_by_score = max(max_per_class_by_score, topk) // max_output_per_class = min( // max_output_per_class, max_per_class_by_score) -// +// // # Sort scores in the descending order and get the sorted indices. // # order = np.argsort(-scores, axis=2) // order = np.full(scores.shape, -1) @@ -552,10 +552,10 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // yOrd = order[b, c, k] // if (scores[b, c, xOrd] < scores[b, c, yOrd]): // tmp = order[b, c, i] -// order[b, c, i] = order[b, c, k] -// order[b, c, k] = tmp -// -// +// order[b, c, i] = order[b, c, k] +// order[b, c, k] = tmp +// +// // # Check if the coordinates are flipped. If so, flip them back. // if (center_point_box == 0): // new_boxes = np.empty(boxes.shape) @@ -574,7 +574,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // x1_max = tmp // new_boxes[b, s] = [y1_min, x1_min, y1_max, x1_max] // boxes = new_boxes -// +// // # Output: [num_selected_indices, 3] // # The selected index format is [batch_index, class_index, box_index]. // num_selected_indices = batch_size * max_output_per_class * class_size @@ -598,17 +598,17 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // # Removed, ignore. // if removed_indices[selected_box_index]: // continue -// +// // # Pick the bounding box with the largest score. // selected_box = boxes[b, selected_box_index, :] -// +// // # Store the index of the picked box to the output. // selected_indices[effective_num_selected_indices] = [b, c, selected_box_index] // // # Update counters. // effective_max_output_per_class += 1 // effective_num_selected_indices += 1 -// +// // # Remove boxes overlapped too much with the selected box, using // # IOU. // for o in range(spatial_size): @@ -618,15 +618,15 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // removed_indices[o] = True // else: // removed_indices[o] = removed_indices[o] -// +// // # Since we cannot suppress by IOU in advance, so remove redundant score // # now. // res = np.empty((effective_num_selected_indices, 3)) // for i in range(effective_num_selected_indices): // res[i] = selected_indices[i] -// return res -// -// +// return res +// +// // print("testing nonmaxsuppression_center_point_box_format") // center_point_box = 1 // boxes = np.array([[ @@ -645,7 +645,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold, center_point_box) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_flipped_coordinates") // boxes = np.array([[ // [1.0, 1.0, 0.0, 0.0], @@ -663,7 +663,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_identical_boxes") // boxes = np.array([[ // [0.0, 0.0, 1.0, 1.0], @@ -671,7 +671,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // [0.0, 0.0, 1.0, 1.0], // [0.0, 0.0, 1.0, 1.0], // [0.0, 0.0, 1.0, 1.0], -// +// // [0.0, 0.0, 1.0, 1.0], // [0.0, 0.0, 1.0, 1.0], // [0.0, 0.0, 1.0, 1.0], @@ -687,7 +687,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_limit_output_size") // boxes = np.array([[ // [0.0, 0.0, 1.0, 1.0], @@ -705,7 +705,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_single_box") // boxes = np.array([[ // [0.0, 0.0, 1.0, 1.0] @@ -718,7 +718,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_suppress_by_IOU") // boxes = np.array([[ // [0.0, 0.0, 1.0, 1.0], @@ -736,7 +736,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_suppress_by_IOU_and_scores") // boxes = np.array([[ // [0.0, 0.0, 1.0, 1.0], @@ -754,7 +754,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_two_batches") // boxes = np.array([[[0.0, 0.0, 1.0, 1.0], // [0.0, 0.1, 1.0, 1.1], @@ -778,7 +778,7 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// +// // print("testing nonmaxsuppression_two_classes") // boxes = np.array([[ // [0.0, 0.0, 1.0, 1.0], @@ -798,8 +798,8 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(RewritePatternSet &patterns, // out = nms(boxes, scores, max_output_boxes_per_class, // iou_threshold, score_threshold) // np.testing.assert_allclose(selected_indices, out) -// -// +// +// // # if __name__ == "__main__": // # main() // clang-format on diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 57e4ff314d..07151fbad7 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -44,7 +44,7 @@ struct ONNXQuantizeLinearOpLowering auto xMemRefType = dyn_cast(X.getType()); auto yMemRefType = dyn_cast( typeConverter->convertType(qlOp.getResult().getType())); - MemRefType yScaleMemRefType = YScale.getType().cast(); + MemRefType yScaleMemRefType = mlir::cast(YScale.getType()); // Types Type elementType = xMemRefType.getElementType(); diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index c6000a6b87..b90fe14696 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -80,11 +80,11 @@ getActivationPack(ONNXGRUOp *op) { // Forward activations. if (activationArrAttr.size() > 0) { activationForward.f.name = - activationArrAttr[0].cast().getValue(); + mlir::cast(activationArrAttr[0]).getValue(); } if (activationArrAttr.size() > 1) { activationForward.g.name = - activationArrAttr[1].cast().getValue(); + mlir::cast(activationArrAttr[1]).getValue(); } } @@ -93,11 +93,12 @@ getActivationPack(ONNXGRUOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 2; if (activationArrAttr.size() > startIndex) { activationReverse.f.name = - activationArrAttr[startIndex].cast().getValue(); + mlir::cast(activationArrAttr[startIndex]).getValue(); } if (activationArrAttr.size() > startIndex + 1) { activationReverse.g.name = - activationArrAttr[startIndex + 1].cast().getValue(); + mlir::cast(activationArrAttr[startIndex + 1]) + .getValue(); } } } @@ -108,10 +109,10 @@ getActivationPack(ONNXGRUOp *op) { if (direction == FORWARD || direction == BIDIRECTIONAL) { // Forward activations. if (activationArrAttr.size() > 0) { - activationForward.f.alpha = activationArrAttr[0].cast(); + activationForward.f.alpha = mlir::cast(activationArrAttr[0]); } if (activationArrAttr.size() > 1) { - activationForward.g.alpha = activationArrAttr[1].cast(); + activationForward.g.alpha = mlir::cast(activationArrAttr[1]); } } @@ -120,11 +121,11 @@ getActivationPack(ONNXGRUOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 2; if (activationArrAttr.size() > startIndex) { activationReverse.f.alpha = - activationArrAttr[startIndex].cast(); + mlir::cast(activationArrAttr[startIndex]); } if (activationArrAttr.size() > startIndex + 1) { activationReverse.g.alpha = - activationArrAttr[startIndex + 1].cast(); + mlir::cast(activationArrAttr[startIndex + 1]); } } } @@ -135,10 +136,10 @@ getActivationPack(ONNXGRUOp *op) { if (direction == FORWARD || direction == BIDIRECTIONAL) { // Forward activations. if (activationArrAttr.size() > 0) { - activationForward.f.beta = activationArrAttr[0].cast(); + activationForward.f.beta = mlir::cast(activationArrAttr[0]); } if (activationArrAttr.size() > 1) { - activationForward.g.beta = activationArrAttr[1].cast(); + activationForward.g.beta = mlir::cast(activationArrAttr[1]); } } @@ -147,11 +148,11 @@ getActivationPack(ONNXGRUOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 2; if (activationArrAttr.size() > startIndex) { activationReverse.f.beta = - activationArrAttr[startIndex].cast(); + mlir::cast(activationArrAttr[startIndex]); } if (activationArrAttr.size() > startIndex + 1) { activationReverse.g.beta = - activationArrAttr[startIndex + 1].cast(); + mlir::cast(activationArrAttr[startIndex + 1]); } } } @@ -179,8 +180,8 @@ getWeightPack( if (op->getLinearBeforeReset() == 0) linearBeforeReset = false; - ArrayRef wShape = W.getType().cast().getShape(); - Type elementType = W.getType().cast().getElementType(); + ArrayRef wShape = mlir::cast(W.getType()).getShape(); + Type elementType = mlir::cast(W.getType()).getElementType(); int64_t hiddenSize = wShape[1] / 3; int64_t inputSize = wShape[2]; @@ -287,8 +288,8 @@ std::tuple getBiasPack( create(rewriter, loc); // Split B. if (!isNoneValue(B)) { - ArrayRef bShape = B.getType().cast().getShape(); - Type elementType = B.getType().cast().getElementType(); + ArrayRef bShape = mlir::cast(B.getType()).getShape(); + Type elementType = mlir::cast(B.getType()).getElementType(); int64_t hiddenSize = bShape[1] / 6; // MemRef types. @@ -379,7 +380,7 @@ GruState allocAndInitializeStates( Value noneValue; initializeIntermediateStates(rewriter, loc, state.forwardHt, state.reverseHt, noneValue, noneValue, operandAdaptor.getInitialH(), noneValue, - operandAdaptor.getX().getType().cast().getElementType(), + mlir::cast(operandAdaptor.getX().getType()).getElementType(), direction, /*onlyHidden=*/true); // Obtain the value of 'linear_before_reset' attribute. @@ -409,17 +410,17 @@ void calculateState( MultiDialectBuilder create(rewriter, loc); - ArrayRef xtShape = Xt.getType().cast().getShape(); + ArrayRef xtShape = mlir::cast(Xt.getType()).getShape(); int64_t batchSize = xtShape[0]; // Get Ht. Value Ht = (isForward) ? state.forwardHt : state.reverseHt; - ArrayRef htShape = Ht.getType().cast().getShape(); + ArrayRef htShape = mlir::cast(Ht.getType()).getShape(); int64_t hiddenSize = htShape[1]; // Frequently used types. - MemRefType matrixType = Ht.getType().cast(); + MemRefType matrixType = mlir::cast(Ht.getType()); unsigned htRank = matrixType.getRank(); Type elementType = matrixType.getElementType(); MemRefType matrixAllGatesType = diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index f8744681c9..dceed2cb5e 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -46,8 +46,8 @@ getWeightPack( // direction StringRef direction = op->getDirection(); - ArrayRef wShape = W.getType().cast().getShape(); - Type elementType = W.getType().cast().getElementType(); + ArrayRef wShape = mlir::cast(W.getType()).getShape(); + Type elementType = mlir::cast(W.getType()).getElementType(); int64_t hiddenSize = wShape[1] / 4; int64_t inputSize = wShape[2]; @@ -127,8 +127,8 @@ std::tuple getBiasPack( // Split B. if (!isNoneValue(B)) { - ArrayRef bShape = B.getType().cast().getShape(); - Type elementType = B.getType().cast().getElementType(); + ArrayRef bShape = mlir::cast(B.getType()).getShape(); + Type elementType = mlir::cast(B.getType()).getElementType(); int64_t hiddenSize = bShape[1] / 8; // MemRef types. @@ -186,8 +186,8 @@ std::tuple getBiasPack( // Split P. if (!isNoneValue(P)) { - ArrayRef pShape = P.getType().cast().getShape(); - Type elementType = P.getType().cast().getElementType(); + ArrayRef pShape = mlir::cast(P.getType()).getShape(); + Type elementType = mlir::cast(P.getType()).getElementType(); int64_t hiddenSize = pShape[1] / 3; // MemRef types. @@ -282,7 +282,7 @@ LstmState allocAndInitializeStates( initializeIntermediateStates(rewriter, loc, state.forwardHt, state.reverseHt, state.forwardCt, state.reverseCt, operandAdaptor.getInitialH(), operandAdaptor.getInitialC(), - operandAdaptor.getX().getType().cast().getElementType(), + mlir::cast(operandAdaptor.getX().getType()).getElementType(), direction, /*onlyHidden=*/false); return state; } @@ -312,18 +312,18 @@ void calculateState create(rewriter, loc); - ArrayRef xtShape = Xt.getType().cast().getShape(); + ArrayRef xtShape = mlir::cast(Xt.getType()).getShape(); int64_t batchSize = xtShape[0]; // Get Ht, Ct. Value Ht = (isForward) ? state.forwardHt : state.reverseHt; Value Ct = (isForward) ? state.forwardCt : state.reverseCt; - ArrayRef htShape = Ht.getType().cast().getShape(); + ArrayRef htShape = mlir::cast(Ht.getType()).getShape(); int64_t hiddenSize = htShape[1]; // Frequently used types. - MemRefType matrixType = Ht.getType().cast(); + MemRefType matrixType = mlir::cast(Ht.getType()); Type elementType = matrixType.getElementType(); MemRefType matrixAllGatesType = MemRefType::get({batchSize, 4 * hiddenSize}, elementType); diff --git a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp index 1c43a064e1..3f9e5e3b00 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNN.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNN.cpp @@ -68,7 +68,7 @@ getActivationPack(ONNXRNNOp *op) { // Forward activations. if (activationArrAttr.size() > 0) { activationForward.f.name = - activationArrAttr[0].cast().getValue(); + mlir::cast(activationArrAttr[0]).getValue(); } } @@ -77,7 +77,7 @@ getActivationPack(ONNXRNNOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 1; if (activationArrAttr.size() > startIndex) { activationReverse.f.name = - activationArrAttr[startIndex].cast().getValue(); + mlir::cast(activationArrAttr[startIndex]).getValue(); } } } @@ -88,7 +88,7 @@ getActivationPack(ONNXRNNOp *op) { if (direction == FORWARD || direction == BIDIRECTIONAL) { // Forward activations. if (activationArrAttr.size() > 0) { - activationForward.f.alpha = activationArrAttr[0].cast(); + activationForward.f.alpha = mlir::cast(activationArrAttr[0]); } } @@ -97,7 +97,7 @@ getActivationPack(ONNXRNNOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 1; if (activationArrAttr.size() > startIndex) { activationReverse.f.alpha = - activationArrAttr[startIndex].cast(); + mlir::cast(activationArrAttr[startIndex]); } } } @@ -108,7 +108,7 @@ getActivationPack(ONNXRNNOp *op) { if (direction == FORWARD || direction == BIDIRECTIONAL) { // Forward activations. if (activationArrAttr.size() > 0) { - activationForward.f.beta = activationArrAttr[0].cast(); + activationForward.f.beta = mlir::cast(activationArrAttr[0]); } } @@ -117,7 +117,7 @@ getActivationPack(ONNXRNNOp *op) { unsigned int startIndex = (direction == REVERSE) ? 0 : 1; if (activationArrAttr.size() > startIndex) { activationReverse.f.beta = - activationArrAttr[startIndex].cast(); + mlir::cast(activationArrAttr[startIndex]); } } } @@ -140,8 +140,8 @@ getWeightPack( // direction StringRef direction = op->getDirection(); - ArrayRef wShape = W.getType().cast().getShape(); - Type elementType = W.getType().cast().getElementType(); + ArrayRef wShape = mlir::cast(W.getType()).getShape(); + Type elementType = mlir::cast(W.getType()).getElementType(); int64_t hiddenSize = wShape[1]; int64_t inputSize = wShape[2]; @@ -214,8 +214,8 @@ std::tuple getBiasPack( // Split B. if (!isNoneValue(B)) { - ArrayRef bShape = B.getType().cast().getShape(); - Type elementType = B.getType().cast().getElementType(); + ArrayRef bShape = mlir::cast(B.getType()).getShape(); + Type elementType = mlir::cast(B.getType()).getElementType(); int64_t hiddenSize = bShape[1] / 2; // MemRef types. @@ -297,7 +297,7 @@ RnnState allocAndInitializeStates( Value noneValue; initializeIntermediateStates(rewriter, loc, state.forwardHt, state.reverseHt, noneValue, noneValue, operandAdaptor.getInitialH(), noneValue, - operandAdaptor.getX().getType().cast().getElementType(), + mlir::cast(operandAdaptor.getX().getType()).getElementType(), direction, /*onlyHidden=*/true); return state; } @@ -326,7 +326,7 @@ void calculateState( // Get Ht. Value Ht = (isForward) ? state.forwardHt : state.reverseHt; - MemRefType matrixType = Ht.getType().cast(); + MemRefType matrixType = mlir::cast(Ht.getType()); unsigned htRank = matrixType.getRank(); // Do matrix multiplications. diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp index 8c680397a5..dab3a299f2 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -42,9 +42,9 @@ Value allocAllHidden(ConversionPatternRewriter &rewriter, Location loc, // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(output.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); alloc = create.mem.alignedAlloc(memRefType, dims); } else { @@ -62,7 +62,7 @@ Value allocIntermediateState( IndexExprScope scope(create.krnlIE); auto memRefType = MemRefType::get({/*batch_size=*/dimAt(X, 1), /*hidden_size=*/dimAt(R, 2)}, - X.getType().cast().getElementType()); + mlir::cast(X.getType()).getElementType()); SmallVector dims; // Get batch_size from X. dims.emplace_back(create.krnlIE.getShapeAsDim(X, 1)); @@ -167,9 +167,9 @@ Value allocHiddenOrCell(ConversionPatternRewriter &rewriter, Location loc, // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(output.getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); alloc = create.mem.alignedAlloc(memRefType, dims); } else { alloc = output; @@ -184,7 +184,7 @@ void initializeHiddenAndCell(ConversionPatternRewriter &rewriter, Location loc, MultiDialectBuilder create( rewriter, loc); Value zero = create.math.constant(elementType, 0); - unsigned htRank = ht.getType().cast().getRank(); + unsigned htRank = mlir::cast(ht.getType()).getRank(); Value iZero = create.math.constantIndex(0); SmallVector htLbs(htRank, iZero); SmallVector htUbs; @@ -222,7 +222,7 @@ void stateToOutputForHiddenOrCell(ConversionPatternRewriter &rewriter, Value numOfElements = getDynamicMemRefSize(rewriter, loc, val); create.krnl.memcpy(output, val, numOfElements); } else { // BIDIRECTIONAL - unsigned rank = forwardVal.getType().cast().getRank(); + unsigned rank = mlir::cast(forwardVal.getType()).getRank(); Value zero = create.math.constantIndex(0); Value one = create.math.constantIndex(1); SmallVector lbs(rank, zero); @@ -258,7 +258,7 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, int64_t batchSize = dimAt(X, 1); int64_t inputSize = dimAt(X, 2); - Type elementType = X.getType().cast().getElementType(); + Type elementType = mlir::cast(X.getType()).getElementType(); MemRefType sliceXType = MemRefType::get({batchSize, inputSize}, elementType); // Allocate a buffer diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp index 59c30062cb..98674670ac 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp @@ -34,7 +34,7 @@ struct ONNXSequenceAtOpLowering : public OpConversionPattern { IndexExprScope IEScope(&rewriter, loc); Type outputMemRefType = - input_sequence.getType().cast().getElementType(); + mlir::cast(input_sequence.getType()).getElementType(); auto dimSize = create.mem.dim(input_sequence, 0); SymbolIndexExpr boundIE(dimSize); diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceEmpty.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceEmpty.cpp index 12b3654be5..43addd357c 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceEmpty.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceEmpty.cpp @@ -32,9 +32,9 @@ struct ONNXSequenceEmptyOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); Value alloc = rewriter.create(loc, outputMemRefType, ValueRange()); diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp index 83f0829a78..fd7bcbf118 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp @@ -40,9 +40,8 @@ struct ONNXSequenceEraseOpLowering Value dimSize = create.mem.dim(input_sequence, 0); SymbolIndexExpr boundIE(dimSize); - MemRefType outputMemRefType = - typeConverter->convertType(seqOp.getResult().getType()) - .cast(); + MemRefType outputMemRefType = mlir::cast( + typeConverter->convertType(seqOp.getResult().getType())); SymbolIndexExpr outputBound = boundIE - 1; Value outputBoundVal = outputBound.getValue(); diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp index 19f0b39a5c..622c80fb5a 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp @@ -37,9 +37,9 @@ struct ONNXSequenceInsertOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(seqOp.getResult().getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); auto input_sequence = adaptor.getInputSequence(); auto dimSize = create.mem.dim(input_sequence, 0); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp index a4b8a390aa..592ba67b1a 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp @@ -63,15 +63,15 @@ struct ONNXArgMinMaxOpLowering : public OpConversionPattern { // Convert the reduced output type to MemRefType. Type convertedType = this->typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType reducedMemRefType = convertedType.cast(); + MemRefType reducedMemRefType = mlir::cast(convertedType); Type reducedElementType = reducedMemRefType.getElementType(); int64_t reducedRank = reducedMemRefType.getRank(); // data input Value data = adaptor.getData(); - MemRefType dataType = data.getType().cast(); + MemRefType dataType = mlir::cast(data.getType()); int64_t dataRank = dataType.getRank(); // axis & keepdims attribute diff --git a/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp b/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp index 139d408c2f..13546b0ebf 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp @@ -87,9 +87,9 @@ struct ONNXCompressOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the result of this operation. Value alloc = diff --git a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp index abed85ce44..1f0833d7f8 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp @@ -56,9 +56,9 @@ struct ONNXConcatOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type outputTensorType = *op->result_type_begin(); Type convertedType = typeConverter->convertType(outputTensorType); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); auto resultShape = outputMemRefType.getShape(); unsigned int rank = resultShape.size(); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp b/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp index 6fbc7b151f..24848eee8b 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp @@ -47,8 +47,8 @@ struct ONNXConcatShapeTransposeOpLowering unsigned numInputs = op->getNumOperands(); Value firstInput = adaptor.getInputs().front(); ArrayRef commonShape = - firstInput.getType().cast().getShape(); - // firstInput.getType().cast().getElementType(); + mlir::cast(firstInput.getType()).getShape(); + // mlir::cast(firstInput.getType()).getElementType(); uint64_t rank = commonShape.size(); int64_t axis = adaptor.getAxis(); @@ -99,7 +99,7 @@ struct ONNXConcatShapeTransposeOpLowering // Alloc and set value for ShapeOp output auto convertedShapeType = - typeConverter->convertType(outputShapeType).cast(); + mlir::cast(typeConverter->convertType(outputShapeType)); Value shapeAlloc = create.mem.alignedAlloc( convertedShapeType, shapeHelper.getOutputDims()); Type elementType = convertedShapeType.getElementType(); @@ -114,7 +114,8 @@ struct ONNXConcatShapeTransposeOpLowering DimsExpr outputTransposeDims = shapeHelper.getOutputDims(1); ArrayAttr permAttr = adaptor.getPermAttr(); Type t = op->getResultTypes()[1]; - auto outputTransposeType = typeConverter->convertType(t).cast(); + auto outputTransposeType = + mlir::cast(typeConverter->convertType(t)); Value alloc = create.mem.alignedAlloc(outputTransposeType, outputTransposeDims); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index dfdc35d910..58e3404dcf 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -33,21 +33,22 @@ struct ONNXConstantOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); // Emit the constant global in Krnl dialect. MultiDialectBuilder create(rewriter, loc); mlir::Attribute constValAttr = constantOp.getValue().value(); - if (memRefType.getElementType().isa()) { + if (mlir::isa(memRefType.getElementType())) { // If the onnx.ConstantOp has string type value attribute, // The element type of the value attribute of krnl.global op should be // "!krnl.string" instead of "!onnx.String". ShapedType constStrType = RankedTensorType::get( memRefType.getShape(), krnl::StringType::get(rewriter.getContext())); SmallVector constStrVector( - constValAttr.dyn_cast().getValues()); + mlir::dyn_cast(constValAttr) + .getValues()); ArrayRef constStrValues(constStrVector); constValAttr = mlir::DenseElementsAttr::get(constStrType, constStrValues); } diff --git a/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp b/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp index 56cd316df6..2b6b98d30e 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp @@ -29,13 +29,13 @@ struct ONNXConstantOfShapeOpLowering Operation *op = constantOp.getOperation(); Location loc = ONNXLoc(op); - auto valueAttr = adaptor.getValue().value().cast(); + auto valueAttr = mlir::cast(adaptor.getValue().value()); // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Type elementType = memRefType.getElementType(); ArrayRef outputShape = memRefType.getShape(); size_t rank = outputShape.size(); @@ -65,13 +65,13 @@ struct ONNXConstantOfShapeOpLowering // Get the constant value from the attribute 'value'. Value constantVal; - if (elementType.isa()) { + if (mlir::isa(elementType)) { auto valueIt = valueAttr.getValues().begin(); - auto valueInt = (*valueIt++).cast().getInt(); + auto valueInt = mlir::cast(*valueIt++).getInt(); constantVal = create.math.constant(elementType, valueInt); - } else if (elementType.isa()) { + } else if (mlir::isa(elementType)) { auto valueIt = valueAttr.getValues().begin(); - auto valueFloat = (*valueIt++).cast().getValueAsDouble(); + auto valueFloat = mlir::cast(*valueIt++).getValueAsDouble(); constantVal = create.math.constant(elementType, valueFloat); } else llvm_unreachable("unsupported element type"); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp b/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp index 2c6d47c25b..21fc67f8da 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp @@ -37,9 +37,9 @@ struct ONNXDimOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); Type elementType = outputMemRefType.getElementType(); // Output is 1D memref of one element. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp b/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp index 5b68fba5e4..b75a94e5a6 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp @@ -41,9 +41,9 @@ struct ONNXExpandOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); int64_t outputRank = outputMemRefType.getRank(); // Insert an allocation and deallocation for the output of this operation. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Flatten.cpp b/src/Conversion/ONNXToKrnl/Tensor/Flatten.cpp index 22bf5d07b2..5638342ea2 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Flatten.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Flatten.cpp @@ -26,7 +26,7 @@ Value insertAllocForFlatten(MemRefType memRefType, Location loc, ConversionPatternRewriter &rewriter, Value input, int64_t axisValue) { MultiDialectBuilder create(rewriter, loc); memref::AllocOp alloc; - auto inputShape = input.getType().cast().getShape(); + auto inputShape = mlir::cast(input.getType()).getShape(); int64_t inputRank = inputShape.size(); SmallVector allocOperands; @@ -62,7 +62,7 @@ struct ONNXFlattenOpLowering : public OpConversionPattern { Location loc = ONNXLoc(op); Value input = adaptor.getInput(); - auto inputTy = input.getType().cast(); + auto inputTy = mlir::cast(input.getType()); auto inputShape = inputTy.getShape(); size_t inputRank = inputShape.size(); int64_t axisValue = flattenOp.getAxis(); @@ -72,9 +72,9 @@ struct ONNXFlattenOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert alloc and dealloc Value alloc = (hasAllConstantDimensions(outputMemRefType)) diff --git a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp index 27a0789dba..35cb8c4f7e 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp @@ -46,9 +46,9 @@ struct ONNXGatherOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the output of this operation. Value alloc = @@ -58,8 +58,8 @@ struct ONNXGatherOpLowering : public OpConversionPattern { Value data = adaptor.getData(); Value indices = adaptor.getIndices(); int64_t axisLit = adaptor.getAxis(); - int64_t dataRank = data.getType().cast().getRank(); - int64_t indicesRank = indices.getType().cast().getRank(); + int64_t dataRank = mlir::cast(data.getType()).getRank(); + int64_t indicesRank = mlir::cast(indices.getType()).getRank(); // Determine whether indices may be negative. bool indicesMayBeNegative = !indicesAreNonNegativeConstants(indices); diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp index 02b9e1b898..9e4db0a1ca 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp @@ -40,9 +40,9 @@ struct ONNXGatherElementsOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the result of this operation. Value output = @@ -52,8 +52,8 @@ struct ONNXGatherElementsOpLowering Value data = adaptor.getData(); Value indices = adaptor.getIndices(); int64_t axis = adaptor.getAxis(); - int64_t dataRank = data.getType().cast().getRank(); - int64_t indicesRank = indices.getType().cast().getRank(); + int64_t dataRank = mlir::cast(data.getType()).getRank(); + int64_t indicesRank = mlir::cast(indices.getType()).getRank(); int64_t outputRank = outputMemRefType.getShape().size(); assert(indicesRank == dataRank && "Input tensors must have the same rank"); assert(outputRank == dataRank && "Output rank not equal to data rank"); diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp index 13497d8cc2..69a38b2fce 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp @@ -67,10 +67,10 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { DimsExpr dataDims, indicesDims; create.krnlIE.getShapeAsDims(data, dataDims); create.krnlIE.getShapeAsDims(indices, indicesDims); - auto dataType = data.getType().cast(); + auto dataType = mlir::cast(data.getType()); int64_t dataRank = dataDims.size(); int64_t indicesRank = indicesDims.size(); - auto indicesType = indices.getType().cast(); + auto indicesType = mlir::cast(indices.getType()); ArrayRef indicesShape = indicesType.getShape(); int64_t indicesLastDim = indicesShape[indicesRank - 1]; // ToFix: Handle case in which indicesLastDim is kDynamic. @@ -80,7 +80,7 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); DimsExpr outputDims = shapeHelper.getOutputDims(); diff --git a/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp b/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp index 8bd326e07a..42311bddd9 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp @@ -87,12 +87,12 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { // Frequently used MemRefType. Value X = adaptor.getX(); - MemRefType xMemRefType = X.getType().cast(); + MemRefType xMemRefType = mlir::cast(X.getType()); // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType resMemRefType = convertedType.cast(); + MemRefType resMemRefType = mlir::cast(convertedType); int64_t xRank = xMemRefType.getRank(); // Frequently used element types. diff --git a/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp b/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp index 80e733c728..bf55e3636f 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp @@ -42,9 +42,9 @@ struct ONNXOneHotOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the output of this operation. Value alloc = diff --git a/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp b/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp index 7a8f84d10a..542b3c2fd8 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp @@ -45,9 +45,9 @@ struct ONNXPadOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType resMemRefType = convertedType.cast(); + MemRefType resMemRefType = mlir::cast(convertedType); Type resElementType = resMemRefType.getElementType(); // Insert an allocation and deallocation for the output of this operation. @@ -69,13 +69,13 @@ struct ONNXPadOpLowering : public OpConversionPattern { // This way is to avoid using `select` in computing indices as doing for // 'edge' and 'reflect' modes. Value cValue; - if (constantValue.getType().isa()) { + if (mlir::isa(constantValue.getType())) { // Default to 0 if constant_value is not specified. cValue = create.math.constant(resElementType, 0); } else { SmallVector loadIndices; MemRefType constantValueType = - constantValue.getType().dyn_cast(); + mlir::dyn_cast(constantValue.getType()); if (constantValueType.getElementType().isF32() && constantValueType.getRank() == 1) { // If the constant_value type is 1xf32 do krnl.load with an index of diff --git a/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp b/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp index 6e3df77d3f..df220ee303 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp @@ -40,7 +40,7 @@ struct ONNXPrintSignatureLowering // Discover the values to print, setting aside the last one. llvm::SmallVector printVal; for (Value oper : adaptor.getInput()) - if (!oper.getType().isa()) + if (!mlir::isa(oper.getType())) printVal.emplace_back(oper); int64_t printNum = printVal.size(); if (printNum == 0) { diff --git a/src/Conversion/ONNXToKrnl/Tensor/Range.cpp b/src/Conversion/ONNXToKrnl/Tensor/Range.cpp index 6b45c67198..b3bd44d8cc 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Range.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Range.cpp @@ -39,15 +39,15 @@ struct ONNXRangeOpLowering : public OpConversionPattern { Value limit = adaptor.getLimit(); Value delta = adaptor.getDelta(); - auto startShape = start.getType().cast().getShape(); - auto limitShape = limit.getType().cast().getShape(); - auto deltaShape = delta.getType().cast().getShape(); + auto startShape = mlir::cast(start.getType()).getShape(); + auto limitShape = mlir::cast(limit.getType()).getShape(); + auto deltaShape = mlir::cast(delta.getType()).getShape(); // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Type elementType = memRefType.getElementType(); // Insert an allocation and deallocation for the result of this operation. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp index 473f1f0938..9d9a1ff1bd 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp @@ -47,9 +47,9 @@ struct ONNXReshapeOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); LLVM_DEBUG(llvm::dbgs() << "memRefType: " << memRefType << "\n"); MultiDialectBuilder diff --git a/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp b/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp index 495cb53397..9527fba3c3 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp @@ -34,9 +34,9 @@ struct ONNXResizeOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); int64_t rank = memRefType.getShape().size(); // Check limitation imposed by implementation diff --git a/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp b/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp index e57503c4c5..64c0b5cd27 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp @@ -41,9 +41,9 @@ struct ONNXReverseSequenceOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); // Insert an allocation and deallocation for the output of this operation. Value alloc = diff --git a/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp b/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp index adfa73cf6d..3cc6a366e3 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp @@ -38,9 +38,9 @@ struct ONNXScatterElementsOpLowering Value updates = adaptor.getUpdates(); Value indices = adaptor.getIndices(); int64_t axis = adaptor.getAxis(); - int64_t dataRank = data.getType().cast().getRank(); - int64_t updatesRank = updates.getType().cast().getRank(); - int64_t indicesRank = indices.getType().cast().getRank(); + int64_t dataRank = mlir::cast(data.getType()).getRank(); + int64_t updatesRank = mlir::cast(updates.getType()).getRank(); + int64_t indicesRank = mlir::cast(indices.getType()).getRank(); assert(updatesRank == dataRank && indicesRank == dataRank && "All input tensors must have the same rank"); @@ -52,9 +52,9 @@ struct ONNXScatterElementsOpLowering // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); int64_t outputRank = outputMemRefType.getShape().size(); assert(outputRank == dataRank && "Output rank not equal to data rank"); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp index 64974d983e..b1fa591ace 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp @@ -33,9 +33,9 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern { Value data = adaptor.getData(); Value updates = adaptor.getUpdates(); Value indices = adaptor.getIndices(); - auto dataType = data.getType().cast(); - auto indicesType = indices.getType().cast(); - auto updatesType = updates.getType().cast(); + auto dataType = mlir::cast(data.getType()); + auto indicesType = mlir::cast(indices.getType()); + auto updatesType = mlir::cast(updates.getType()); int64_t dataRank = dataType.getRank(); int64_t updatesRank = updatesType.getRank(); int64_t indicesRank = indicesType.getRank(); @@ -45,9 +45,9 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); int64_t outputRank = outputMemRefType.getShape().size(); assert(outputRank == dataRank && "Output rank not equal to data rank"); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp index db2e4b92a8..95c01a1859 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp @@ -40,9 +40,9 @@ struct ONNXShapeOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); Type elementType = outputMemRefType.getElementType(); // TODO: if the dimensions are known at compile time diff --git a/src/Conversion/ONNXToKrnl/Tensor/Size.cpp b/src/Conversion/ONNXToKrnl/Tensor/Size.cpp index 9f743a66eb..ffd0629b2b 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Size.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Size.cpp @@ -32,14 +32,15 @@ struct ONNXSizeOpLowering : public OpConversionPattern { rewriter, loc); Value data = adaptor.getData(); - ArrayRef dataShape = data.getType().cast().getShape(); + ArrayRef dataShape = + mlir::cast(data.getType()).getShape(); Value resultOperand = sizeOp.getSize(); // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Value alloc = create.mem.alignedAlloc(resultOperand, memRefType); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp b/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp index f36cc18caa..659be67d41 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp @@ -37,9 +37,9 @@ struct ONNXSliceOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); int64_t outputRank = outputMemRefType.getShape().size(); // Insert an allocation and deallocation for the output of this operation. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp index 79d6000c28..e77d437a48 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp @@ -46,9 +46,9 @@ LogicalResult ONNXSplitOpLoweringCommon(OP_TYPE splitOp, OP_ADAPTOR adaptor, // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(splitOp.getOutputs()[i].getType()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); Value alloc = create.mem.alignedAlloc(memRefType, shapeHelper.getOutputDims(i)); allocs.emplace_back(alloc); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp index 4e9c078cbb..162910b8c3 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp @@ -31,7 +31,7 @@ LogicalResult ONNXSqueezeOpLoweringCommon(OP_TYPE squeezeOp, OP_ADAPTOR adaptor, // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); // Get shape. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp b/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp index cba4f5f5bb..e4dc340c97 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp @@ -28,7 +28,7 @@ Value insertAllocForTile(MemRefType memRefType, Location loc, Value repeatsOperand) { MultiDialectBuilder create( rewriter, loc); - auto inputShape = inputOperand.getType().cast().getShape(); + auto inputShape = mlir::cast(inputOperand.getType()).getShape(); size_t inputRank = inputShape.size(); auto outputShape = memRefType.getShape(); @@ -67,9 +67,9 @@ struct ONNXTileOpLowering : public OpConversionPattern { // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType memRefType = convertedType.cast(); + MemRefType memRefType = mlir::cast(convertedType); llvm::ArrayRef memRefShape = memRefType.getShape(); uint64_t outputRank = memRefShape.size(); @@ -122,15 +122,15 @@ struct ONNXTileOpLoweringAlternative : public OpConversionPattern { // get input operands, shapes, and rank Value input = adaptor.getInput(); - auto inputShape = input.getType().cast().getShape(); + auto inputShape = mlir::cast(input.getType()).getShape(); int64_t inputRank = inputShape.size(); Value repeats = adaptor.getRepeats(); // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); - MemRefType outputMemRefType = convertedType.cast(); + MemRefType outputMemRefType = mlir::cast(convertedType); auto outputMemRefShape = outputMemRefType.getShape(); int64_t outputRank = outputMemRefShape.size(); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp index d55d1dfd2c..94c6d7bd32 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp @@ -48,12 +48,12 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { auto permAttr = adaptor.getPerm(); // Input and output types. - MemRefType inMemRefType = data.getType().cast(); + MemRefType inMemRefType = mlir::cast(data.getType()); Type outConvertedType = typeConverter->convertType(*op->result_type_begin()); - assert(outConvertedType && outConvertedType.isa() && + assert(outConvertedType && mlir::isa(outConvertedType) && "Failed to convert type to MemRefType"); - MemRefType outMemRefType = outConvertedType.cast(); + MemRefType outMemRefType = mlir::cast(outConvertedType); // Get shape. ONNXTransposeOpShapeHelper shapeHelper(op, operands, &create.krnlIE); @@ -141,7 +141,7 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { void scalarTranspose(Operation *op, Value inputMemRef, Value outputMemRef, std::optional permAttr, MDBuilder *create, bool enableParallel) const { - uint64_t rank = outputMemRef.getType().cast().getRank(); + uint64_t rank = mlir::cast(outputMemRef.getType()).getRank(); ValueRange loopDef = create->krnl.defineLoops(rank); SmallVector lbs(rank, LiteralIndexExpr(0)); SmallVector ubs; @@ -172,7 +172,7 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { std::optional permAttr, MDBuilder *create, int numLastDims, bool enableParallel) const { Type i64Ty = create->math.getBuilder().getI64Type(); - MemRefType inMemRefType = inputMemRef.getType().cast(); + MemRefType inMemRefType = mlir::cast(inputMemRef.getType()); uint64_t rank = inMemRefType.getRank(); uint64_t outerRank = rank - numLastDims; diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp index 91c6ab594b..d43f14a7cb 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp @@ -30,7 +30,7 @@ Value emitArgUnique(ConversionPatternRewriter &rewriter, Location loc, MultiDialectBuilder create( rewriter, loc); IndexExprScope scope(create.krnl); - MemRefType inputMemRefType = input.getType().cast(); + MemRefType inputMemRefType = mlir::cast(input.getType()); int64_t rank = inputMemRefType.getRank(); assert(axis < rank && "axis is out of bound"); LiteralIndexExpr zeroIE(0), oneIE(1); @@ -105,7 +105,7 @@ struct ONNXUniqueOpLowering : public ConversionPattern { SmallVector XDims; create.krnlIE.getShapeAsDims(X, XDims); - Type elementType = X.getType().cast().getElementType(); + Type elementType = mlir::cast(X.getType()).getElementType(); int64_t rank = create.krnlIE.getShapedTypeRank(X); int64_t sorted = operandAdaptor.getSorted(); std::optional optionalAxis = uniqueOp.getAxis(); @@ -162,9 +162,8 @@ struct ONNXUniqueOpLowering : public ConversionPattern { Value outputY; if (hasStaticShape(uniqueOp.getY().getType())) { // This is a patch related to https://github.com/onnx/onnx/issues/6133 - MemRefType memrefType = - typeConverter->convertType(uniqueOp.getY().getType()) - .cast(); + MemRefType memrefType = mlir::cast( + typeConverter->convertType(uniqueOp.getY().getType())); outputY = create.mem.alignedAlloc(memrefType); } else if (axis < 0) { MemRefType memrefType = @@ -187,9 +186,8 @@ struct ONNXUniqueOpLowering : public ConversionPattern { isNoneValue(uniqueOp.getIndices()) ? emptyMemref : (hasStaticShape(indicesType) - ? create.mem.alignedAlloc( - typeConverter->convertType(indicesType) - .cast()) + ? create.mem.alignedAlloc(mlir::cast( + typeConverter->convertType(indicesType))) : create.mem.alignedAlloc(memrefType, outputIndexDims)); Type inverseIndicesType = uniqueOp.getInverseIndices().getType(); @@ -197,9 +195,8 @@ struct ONNXUniqueOpLowering : public ConversionPattern { isNoneValue(uniqueOp.getInverseIndices()) ? emptyMemref : (hasStaticShape(inverseIndicesType) - ? create.mem.alignedAlloc( - typeConverter->convertType(inverseIndicesType) - .cast()) + ? create.mem.alignedAlloc(mlir::cast( + typeConverter->convertType(inverseIndicesType))) : create.mem.alignedAlloc( memrefType, outputInverseIndexDims)); @@ -208,9 +205,8 @@ struct ONNXUniqueOpLowering : public ConversionPattern { isNoneValue(uniqueOp.getCounts()) ? emptyMemref : (hasStaticShape(countsType) - ? create.mem.alignedAlloc( - typeConverter->convertType(countsType) - .cast()) + ? create.mem.alignedAlloc(mlir::cast( + typeConverter->convertType(countsType))) : create.mem.alignedAlloc(memrefType, outputIndexDims)); // // Emit a Unique call to get the outputs diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp index c9ca83abad..70be437635 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp @@ -32,7 +32,7 @@ LogicalResult ONNXUnsqueezeOpLoweringCommon(OP_TYPE unsqueezeOp, // Convert the output type to MemRefType. Type convertedType = typeConverter->convertType(*op->result_type_begin()); - assert(convertedType && convertedType.isa() && + assert(convertedType && mlir::isa(convertedType) && "Failed to convert type to MemRefType"); // Get shape. diff --git a/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp b/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp index 85f1ec69ae..90f15db2ff 100644 --- a/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp +++ b/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp @@ -31,7 +31,7 @@ Value StablehloBuilder::constant(Type type, double val) const { Value constant = nullptr; // Could be a vector type; look at the element type. Type elementType = type; - VectorType vectorType = type.dyn_cast(); + VectorType vectorType = mlir::dyn_cast(type); if (vectorType) elementType = vectorType.getElementType(); TypeSwitch(elementType) @@ -133,7 +133,7 @@ Value OnnxToStablehloBuilder::reshape( const Value input, const ArrayRef shapeDims) const { assert(!shapeDims.empty() && "Shape dimensions should not be empty"); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); MultiDialectBuilder create( b(), loc()); @@ -190,7 +190,7 @@ Value OnnxToStablehloBuilder::transpose(const Value input, shape.push_back(dim.isLiteral() ? dim.getLiteral() : ShapedType::kDynamic); // Create the "onnx.Transpose" operation. - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Value transposeRes = create.onnx.transpose( RankedTensorType::get(shape, inputType.getElementType()), input, b().getI64ArrayAttr(perm)); @@ -213,19 +213,19 @@ ElementsAttr IndexExprBuilderForStablehlo::getConst(Value value) { } if (auto constOp = dyn_cast_or_null(definingOp)) { if (constOp.getValueAttr()) - return constOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constOp.getValueAttr()); } else if (auto constOp = dyn_cast_or_null(definingOp)) { if (constOp.getValue().has_value()) - return constOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constOp.getValueAttr()); } return nullptr; } Value IndexExprBuilderForStablehlo::getVal(Value intArrayVal, uint64_t i) { Type elemType = getElementType(intArrayVal.getType()); - if (!elemType.isa()) { + if (!mlir::isa(elemType)) { Type indexTensorType = RankedTensorType::get( - intArrayVal.getType().cast().getShape(), + mlir::cast(intArrayVal.getType()).getShape(), b().getIndexType()); intArrayVal = b().create(loc(), indexTensorType, intArrayVal); diff --git a/src/Conversion/ONNXToStablehlo/Math/Clip.cpp b/src/Conversion/ONNXToStablehlo/Math/Clip.cpp index b98360105e..8e50ee1522 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Clip.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Clip.cpp @@ -44,7 +44,7 @@ struct ONNXClipOpLoweringToStablehlo : public ConversionPattern { Type outputType = *op->result_type_begin(); assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); - ShapedType outputShapedType = outputType.cast(); + ShapedType outputShapedType = mlir::cast(outputType); Type elemType = outputShapedType.getElementType(); MathBuilder createMath(rewriter, loc); diff --git a/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp b/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp index b5b58f2bd6..5d564d3510 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Elementwise.cpp @@ -188,7 +188,7 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo Type resultType = *op->result_type_begin(); Value inp = operandAdaptor.getX(); - ShapedType inpType = inp.getType().dyn_cast_or_null(); + ShapedType inpType = mlir::dyn_cast_or_null(inp.getType()); if (inpType == nullptr) return failure(); Value alphaVal = getShapedFloat(loc, rewriter, alpha, inp); @@ -220,7 +220,7 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo double alpha = HardSigmoidOp.getAlpha().convertToDouble(); double beta = HardSigmoidOp.getBeta().convertToDouble(); Value inp = operandAdaptor.getX(); - ShapedType inpType = inp.getType().dyn_cast_or_null(); + ShapedType inpType = mlir::dyn_cast_or_null(inp.getType()); if (inpType == nullptr) return failure(); Value alphaVal = getShapedFloat(loc, rewriter, alpha, inp); @@ -247,7 +247,7 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo Location loc = op->getLoc(); ONNXReluOpAdaptor adaptor(operands, op->getAttrDictionary()); Value inp = adaptor.getX(); - ShapedType inpType = inp.getType().dyn_cast_or_null(); + ShapedType inpType = mlir::dyn_cast_or_null(inp.getType()); if (inpType == nullptr) return failure(); Type resultType = *op->result_type_begin(); @@ -271,7 +271,7 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo ONNXLeakyReluOpAdaptor adaptor(operands, op->getAttrDictionary()); Value inp = adaptor.getX(); llvm::APFloat alpha = adaptor.getAlpha(); - ShapedType inpType = inp.getType().dyn_cast_or_null(); + ShapedType inpType = mlir::dyn_cast_or_null(inp.getType()); if (inpType == nullptr) return failure(); Type resultType = *op->result_type_begin(); @@ -299,7 +299,7 @@ struct ONNXElementwiseUnaryOpLoweringToStablehlo ONNXCastOpAdaptor adaptor(operands, op->getAttrDictionary()); Value inp = adaptor.getInput(); Type elementToType = adaptor.getTo(); - ShapedType inpType = inp.getType().dyn_cast_or_null(); + ShapedType inpType = mlir::dyn_cast_or_null(inp.getType()); if (inpType == nullptr) return failure(); Value resultOp = diff --git a/src/Conversion/ONNXToStablehlo/Math/Gemm.cpp b/src/Conversion/ONNXToStablehlo/Math/Gemm.cpp index f7be589cd5..68d229f2b9 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Gemm.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Gemm.cpp @@ -51,7 +51,8 @@ struct ONNXGemmOpLoweringToStablehlo : public ConversionPattern { if (gemmOp.getTransB() == 1) transB = rewriter.create( loc, B, rewriter.getDenseI64ArrayAttr({1, 0})); - ShapedType resultType = gemmOp.getType().dyn_cast_or_null(); + ShapedType resultType = + mlir::dyn_cast_or_null(gemmOp.getType()); Value dot = rewriter.create( loc, gemmOp.getType(), transA, transB, nullptr); bool hasBias = shapeHelper.hasBias; @@ -143,7 +144,7 @@ struct ONNXGemmOpLoweringToStablehlo : public ConversionPattern { ONNXGemmOpShapeHelper shapeHelper(op, {}); shapeHelper.computeShapeAndAssertOnFailure(); - ShapedType outpType = gemmOp.getType().dyn_cast(); + ShapedType outpType = mlir::dyn_cast(gemmOp.getType()); if (outpType == nullptr) return failure(); Type elemType = outpType.getElementType(); diff --git a/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp b/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp index dc0c95ff95..e40057b5d3 100644 --- a/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp @@ -41,12 +41,12 @@ struct ONNXMatMulOpLoweringToStablehlo : public ConversionPattern { Type outputType = *op->result_type_begin(); assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); - ShapedType outputShapedType = outputType.cast(); + ShapedType outputShapedType = mlir::cast(outputType); Type elementType = outputShapedType.getElementType(); Value A(operandAdaptor.getA()), B(operandAdaptor.getB()); - auto aRank = A.getType().cast().getRank(); - auto bRank = B.getType().cast().getRank(); + auto aRank = mlir::cast(A.getType()).getRank(); + auto bRank = mlir::cast(B.getType()).getRank(); // Size all the arrays to padded length. int paddedRank = std::max(aRank, bRank); paddedRank = std::max(paddedRank, 2); @@ -87,7 +87,8 @@ struct ONNXMatMulOpLoweringToStablehlo : public ConversionPattern { const Value &operandToMatch, ArrayRef shapeInts, int64_t oneDPad) { Value broadcasted; - auto rank = operandToBroadcast.getType().cast().getRank(); + auto rank = + mlir::cast(operandToBroadcast.getType()).getRank(); RankedTensorType broadCastedType = RankedTensorType::get(shapeInts, elementType); SmallVector broadcastDimensions = diff --git a/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp b/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp index 4a476ad9ee..7a9054c68d 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Reduction.cpp @@ -120,7 +120,7 @@ llvm::SmallVector getDefinedAxes(Operation *op) { ArrayAttr axisAttrs = llvm::dyn_cast(op).getAxesAttr(); if (axisAttrs) { for (Attribute axisAttr : axisAttrs.getValue()) { - int64_t axis = axisAttr.cast().getInt(); + int64_t axis = mlir::cast(axisAttr).getInt(); definedAxes.push_back(axis); } } @@ -133,9 +133,8 @@ llvm::SmallVector getDefinedAxesFromConstAxes( // Assume it is verified that axes are known. Convert DenseElementsAttr to // ArrayAttr. if (!isNoneValue(axesValue) && getONNXConstantOp(axesValue)) { - mlir::ElementsAttr constAxes = getONNXConstantOp(axesValue) - .getValueAttr() - .dyn_cast_or_null(); + mlir::ElementsAttr constAxes = mlir::dyn_cast_or_null( + getONNXConstantOp(axesValue).getValueAttr()); for (mlir::IntegerAttr element : constAxes.getValues()) definedAxes.push_back(element.getInt()); return definedAxes; @@ -144,9 +143,9 @@ llvm::SmallVector getDefinedAxesFromConstAxes( return definedAxes; // Dynamic axes RankedTensorType inputType = - op->getOperands()[0].getType().dyn_cast(); + mlir::dyn_cast(op->getOperands()[0].getType()); RankedTensorType outputType = - op->getResultTypes()[0].dyn_cast(); + mlir::dyn_cast(op->getResultTypes()[0]); assert(inputType != nullptr && outputType != nullptr && "not implemented for dynamic axes when either input or output is not " "ranked"); @@ -308,7 +307,7 @@ SmallVector getReductionShape(ShapedType inputType, Value getReductionShapeValue(Location loc, PatternRewriter &rewriter, Value operand, llvm::SmallVector axes, bool keepDims) { - int64_t rank = operand.getType().cast().getRank(); + int64_t rank = mlir::cast(operand.getType()).getRank(); // Mark reduction axes. llvm::SmallVector isReductionAxis; for (int64_t i = 0; i < rank; ++i) { @@ -363,7 +362,8 @@ template Value createReduce(Location loc, Value operand, Value identity, SmallVector &reduceShape, llvm::SmallVector axes, PatternRewriter &rewriter, bool keepDims, ShapedType outputType) { - RankedTensorType operandType = operand.getType().cast(); + RankedTensorType operandType = + mlir::cast(operand.getType()); Type reduceResultType = RankedTensorType::get(reduceShape, operandType.getElementType()); stablehlo::ReduceOp reduce = rewriter.create(loc, @@ -411,11 +411,11 @@ struct ONNXReductionOpLoweringToStablehlo : public ConversionPattern { Location loc = op->getLoc(); Value input = operands[0]; // Type - RankedTensorType inputType = input.getType().cast(); + RankedTensorType inputType = mlir::cast(input.getType()); if (inputType == nullptr) return failure(); Type resultType = *op->result_type_begin(); - ShapedType outputType = resultType.cast(); + ShapedType outputType = mlir::cast(resultType); if (outputType == nullptr) return failure(); Type elemType = inputType.getElementType(); @@ -457,7 +457,7 @@ struct ONNXReductionOpLoweringToStablehlo : public ConversionPattern { loc, reduceResult, reduceFactorValue); } else { Value ones; - if (elemType.isa()) + if (mlir::isa(elemType)) ones = getShapedInt(loc, rewriter, 1, input); else ones = getShapedFloat(loc, rewriter, 1.0, input); diff --git a/src/Conversion/ONNXToStablehlo/NN/Conv.cpp b/src/Conversion/ONNXToStablehlo/NN/Conv.cpp index 55f142bc6d..f62db593fb 100644 --- a/src/Conversion/ONNXToStablehlo/NN/Conv.cpp +++ b/src/Conversion/ONNXToStablehlo/NN/Conv.cpp @@ -47,12 +47,12 @@ struct ONNXConvOpLoweringToStablehlo : public ConversionPattern { Value inputOperand = operandAdaptor.getX(); Value filterOperand = operandAdaptor.getW(); Value biasOperand = operandAdaptor.getB(); - bool hasBias = !biasOperand.getType().isa(); + bool hasBias = !mlir::isa(biasOperand.getType()); int64_t groupNum = convOp.getGroup(); assert(isRankedShapedType(inputOperand.getType()) && "Expected Ranked ShapedType"); - ShapedType inputType = inputOperand.getType().cast(); + ShapedType inputType = mlir::cast(inputOperand.getType()); llvm::ArrayRef inputShape = inputType.getShape(); Type outputType = *op->result_type_begin(); // Onnx Input is NCHW diff --git a/src/Conversion/ONNXToStablehlo/NN/ConvTranspose.cpp b/src/Conversion/ONNXToStablehlo/NN/ConvTranspose.cpp index 33360efd70..63ecb71d84 100644 --- a/src/Conversion/ONNXToStablehlo/NN/ConvTranspose.cpp +++ b/src/Conversion/ONNXToStablehlo/NN/ConvTranspose.cpp @@ -31,7 +31,7 @@ struct ONNXConvTransposeOpLoweringToStablehlo : public ConversionPattern { Value filterOperand, int64_t groupNum, int rank) const { assert(isRankedShapedType(filterOperand.getType()) && "Expected Ranked ShapedType"); - ShapedType filterType = filterOperand.getType().cast(); + ShapedType filterType = mlir::cast(filterOperand.getType()); assert(filterType.hasStaticShape() && "Expected static shape for filter"); ArrayRef filterShape = filterType.getShape(); Type elemType = filterType.getElementType(); @@ -82,12 +82,12 @@ struct ONNXConvTransposeOpLoweringToStablehlo : public ConversionPattern { Value inputOperand = operandAdaptor.getX(); Value filterOperand = operandAdaptor.getW(); Value biasOperand = operandAdaptor.getB(); - bool hasBias = !biasOperand.getType().isa(); + bool hasBias = !mlir::isa(biasOperand.getType()); int64_t groupNum = convOp.getGroup(); assert(isRankedShapedType(inputOperand.getType()) && "Expected Ranked ShapedType"); - ShapedType inputType = inputOperand.getType().cast(); + ShapedType inputType = mlir::cast(inputOperand.getType()); // Onnx Input is NCHW int64_t spatialOffset = 2; int64_t rank = inputType.getRank(); diff --git a/src/Conversion/ONNXToStablehlo/NN/Pooling.cpp b/src/Conversion/ONNXToStablehlo/NN/Pooling.cpp index 1ff55fe336..0f5784745b 100644 --- a/src/Conversion/ONNXToStablehlo/NN/Pooling.cpp +++ b/src/Conversion/ONNXToStablehlo/NN/Pooling.cpp @@ -27,17 +27,18 @@ static Value createInitialValueForPoolingOp( Location loc = op->getLoc(); if (isa(op)) { // returns negative infinity - return rewriter.create( - loc, rewriter.getFloatAttr(elemType, - APFloat::getInf(elemType.cast().getFloatSemantics(), - /*isNegative=*/true))); + return rewriter.create(loc, + rewriter.getFloatAttr(elemType, + APFloat::getInf(mlir::cast(elemType).getFloatSemantics(), + /*isNegative=*/true))); } if (isa(op)) { // returns negative infinity - return rewriter.create(loc, - rewriter.getFloatAttr(elemType, - APFloat::getZero(elemType.cast().getFloatSemantics(), - /*isNegative=*/false))); + return rewriter.create( + loc, rewriter.getFloatAttr(elemType, + APFloat::getZero( + mlir::cast(elemType).getFloatSemantics(), + /*isNegative=*/false))); } op->emitError("unimplemented lowering for onnx pooling op\n"); return nullptr; @@ -117,7 +118,7 @@ struct ONNXPoolOpLoweringToStablehlo : public ConversionPattern { // Type information about the input and result of this operation. Value inputOperand = operandAdaptor.getX(); RankedTensorType inputType = - inputOperand.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(inputOperand.getType()); if (inputType == nullptr) return failure(); llvm::ArrayRef inputShape = inputType.getShape(); diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp index bb6fbeee01..9dc5a243be 100644 --- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp +++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp @@ -25,7 +25,7 @@ namespace onnx_mlir { Value getShapedZero( Location loc, ConversionPatternRewriter &rewriter, Value &inp) { - ShapedType inpType = inp.getType().cast(); + ShapedType inpType = mlir::cast(inp.getType()); Value broadcastedZero; if (inpType.hasStaticShape()) broadcastedZero = rewriter.create( @@ -45,19 +45,19 @@ llvm::SmallVector getBroadcastedOperands(Operation *op, ConversionPatternRewriter &rewriter, Location loc, int64_t outputRank) { llvm::SmallVector broadcastedOperands; Type outputType = *op->result_type_begin(); - assert(outputType.isa() && "output type is not shaped"); - ShapedType outputShapedType = outputType.cast(); + assert(mlir::isa(outputType) && "output type is not shaped"); + ShapedType outputShapedType = mlir::cast(outputType); Value resultExtents = mlir::hlo::computeNaryElementwiseBroadcastingResultExtents( loc, op->getOperands(), rewriter); for (Value operand : op->getOperands()) { RankedTensorType operandType = - operand.getType().dyn_cast(); + mlir::dyn_cast(operand.getType()); assert(operandType != nullptr && "operand type is not ranked"); SmallVector broadcastDimensions = llvm::to_vector<4>( llvm::seq(outputRank - operandType.getRank(), outputRank)); Type elementType = - operand.getType().dyn_cast().getElementType(); + mlir::dyn_cast(operand.getType()).getElementType(); RankedTensorType broadcastedOutputType = RankedTensorType::get(outputShapedType.getShape(), elementType); Value broadcast = rewriter.create(loc, @@ -72,19 +72,19 @@ llvm::SmallVector getBroadcastedOperands( llvm::SmallVector &operands, Type outputType, ConversionPatternRewriter &rewriter, Location loc, int64_t outputRank) { llvm::SmallVector broadcastedOperands; - assert(outputType.isa() && "output type is not shaped"); - ShapedType outputShapedType = outputType.cast(); + assert(mlir::isa(outputType) && "output type is not shaped"); + ShapedType outputShapedType = mlir::cast(outputType); Value resultExtents = mlir::hlo::computeNaryElementwiseBroadcastingResultExtents( loc, operands, rewriter); for (Value operand : operands) { RankedTensorType operandType = - operand.getType().dyn_cast(); + mlir::dyn_cast(operand.getType()); assert(operandType != nullptr && "operand type is not ranked"); SmallVector broadcastDimensions = llvm::to_vector<4>( llvm::seq(outputRank - operandType.getRank(), outputRank)); Type elementType = - operands[0].getType().dyn_cast().getElementType(); + mlir::dyn_cast(operands[0].getType()).getElementType(); RankedTensorType broadcastedOutputType = RankedTensorType::get(outputShapedType.getShape(), elementType); Value broadcast = rewriter.create(loc, @@ -98,11 +98,11 @@ llvm::SmallVector getBroadcastedOperands( ElementsAttr getElementAttributeFromConstValue(Value value) { auto definingOp = value.getDefiningOp(); if (auto constantOp = dyn_cast_or_null(definingOp)) { - return constantOp.getValue().dyn_cast(); + return mlir::dyn_cast(constantOp.getValue()); } else if (auto constantOp = dyn_cast_or_null(definingOp)) { if (constantOp.getValue().has_value()) - return constantOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constantOp.getValueAttr()); } return nullptr; } @@ -120,11 +120,11 @@ namespace { DenseElementsAttr getDenseElementAttrFromConstValue(mlir::Value value) { Operation *definingOp = value.getDefiningOp(); if (auto globalOp = dyn_cast_or_null(definingOp)) { - return globalOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(globalOp.getValueAttr()); } else if (auto constOp = dyn_cast_or_null(definingOp)) { if (constOp.getValue().has_value()) - return constOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constOp.getValueAttr()); } return nullptr; } diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp index 5a8c0b5666..c2b99f8e02 100644 --- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp +++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp @@ -70,7 +70,7 @@ template Value getShapedFloat(Location loc, ConversionPatternRewriter &rewriter, const T &value, Value &inp) { Value broadcastedValue; - ShapedType inpType = inp.getType().cast(); + ShapedType inpType = mlir::cast(inp.getType()); if (inpType.hasStaticShape()) broadcastedValue = rewriter.create( loc, DenseElementsAttr::get(inpType, @@ -92,7 +92,7 @@ template Value getShapedInt(Location loc, ConversionPatternRewriter &rewriter, const T &value, Value &inp) { Value broadcastedValue; - ShapedType inpType = inp.getType().cast(); + ShapedType inpType = mlir::cast(inp.getType()); if (inpType.hasStaticShape()) broadcastedValue = rewriter.create( loc, DenseElementsAttr::get(inpType, diff --git a/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp b/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp index e3d4ee24a5..32186f88ac 100644 --- a/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp @@ -53,8 +53,8 @@ getWeightPack( // direction StringRef direction = op->getDirection(); - ArrayRef wShape = W.getType().cast().getShape(); - Type elementType = W.getType().cast().getElementType(); + ArrayRef wShape = mlir::cast(W.getType()).getShape(); + Type elementType = mlir::cast(W.getType()).getElementType(); int64_t hiddenSize = wShape[1] / 4; int64_t inputSize = wShape[2]; @@ -136,8 +136,8 @@ std::tuple getBiasPack( // Split B. if (!isNoneValue(B)) { - ArrayRef bShape = B.getType().cast().getShape(); - Type elementType = B.getType().cast().getElementType(); + ArrayRef bShape = mlir::cast(B.getType()).getShape(); + Type elementType = mlir::cast(B.getType()).getElementType(); int64_t hiddenSize = bShape[1] / 8; // MemRef types. @@ -195,8 +195,8 @@ std::tuple getBiasPack( // Split P. if (!isNoneValue(P)) { - ArrayRef pShape = P.getType().cast().getShape(); - Type elementType = P.getType().cast().getElementType(); + ArrayRef pShape = mlir::cast(P.getType()).getShape(); + Type elementType = mlir::cast(P.getType()).getElementType(); int64_t hiddenSize = pShape[1] / 3; // MemRef types. @@ -293,7 +293,8 @@ LstmState allocAndInitializeStates( initializeIntermediateStates(rewriter, loc, state.forwardHt, state.reverseHt, state.forwardCt, state.reverseCt, operandAdaptor.getInitialH(), operandAdaptor.getInitialC(), - operandAdaptor.getX().getType().cast().getElementType(), + mlir::cast(operandAdaptor.getX().getType()) + .getElementType(), direction, /*onlyHidden=*/false); return state; } @@ -315,18 +316,18 @@ void calculateState create(rewriter, loc); - ArrayRef xtShape = Xt.getType().cast().getShape(); + ArrayRef xtShape = mlir::cast(Xt.getType()).getShape(); int64_t batchSize = xtShape[0]; // Get Ht, Ct. Value Ht = (isForward) ? state.forwardHt : state.reverseHt; Value Ct = (isForward) ? state.forwardCt : state.reverseCt; - ArrayRef htShape = Ht.getType().cast().getShape(); + ArrayRef htShape = mlir::cast(Ht.getType()).getShape(); int64_t hiddenSize = htShape[1]; // Frequently used types. - RankedTensorType matrixType = Ht.getType().cast(); + RankedTensorType matrixType = mlir::cast(Ht.getType()); Type elementType = matrixType.getElementType(); RankedTensorType matrixAllGatesType = RankedTensorType::get({batchSize, 4 * hiddenSize}, elementType); @@ -452,10 +453,11 @@ void stateToOutput(ConversionPatternRewriter &rewriter, outputs.emplace_back(create.onnx.concat( op->getY().getType(), ValueRange(state.reverseAllH), 0)); } else { - auto outputShape = op->getY().getType().cast().getShape(); + auto outputShape = + mlir::cast(op->getY().getType()).getShape(); RankedTensorType singleDirectionType = RankedTensorType::get( {outputShape[0], 1, outputShape[2], outputShape[3]}, - op->getY().getType().cast().getElementType()); + mlir::cast(op->getY().getType()).getElementType()); outputs.emplace_back(create.onnx.concat(op->getY().getType(), {create.onnx.concat( singleDirectionType, ValueRange(state.forwardAllH), 0), diff --git a/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp b/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp index cb30905dcb..c61bca11e7 100644 --- a/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp @@ -28,7 +28,7 @@ Value allocAllHidden( MultiDialectBuilder create(rewriter, loc); RankedTensorType zeroType = RankedTensorType::get({dimAt(X, 0), 1, dimAt(X, 1), dimAt(R, 2)}, - X.getType().cast().getElementType()); + mlir::cast(X.getType()).getElementType()); DenseElementsAttr zeroAttr = DenseElementsAttr::get(zeroType, 0.0f); return create.onnx.constant(zeroAttr); } @@ -40,7 +40,7 @@ mlir::Value allocHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, RankedTensorType zeroType = RankedTensorType::get( {/*num_directions=*/dimAt(W, 0), /*batch_size=*/dimAt(X, 1), /*hidden_size=*/dimAt(R, 2)}, - X.getType().cast().getElementType()); + mlir::cast(X.getType()).getElementType()); DenseElementsAttr zeroAttr = DenseElementsAttr::get(zeroType, 0.0f); return create.onnx.constant(zeroAttr); } @@ -53,7 +53,7 @@ Value allocIntermediateState( RankedTensorType zeroType = RankedTensorType::get({/*batch_size=*/dimAt(X, 1), /*hidden_size=*/dimAt(R, 2)}, - X.getType().cast().getElementType()); + mlir::cast(X.getType()).getElementType()); DenseElementsAttr zeroAttr = DenseElementsAttr::get(zeroType, 0.0f); return create.onnx.constant(zeroAttr); } @@ -73,12 +73,12 @@ void initializeIntermediateStates(ConversionPatternRewriter &rewriter, Value boundVal = (direction == FORWARD || direction == BIDIRECTIONAL) ? forwardHt : reverseHt; - auto valShape = boundVal.getType().cast().getShape(); + auto valShape = mlir::cast(boundVal.getType()).getShape(); SmallVector sliceSizes = {1, valShape[0], valShape[1]}; SmallVector firstStartIndices = {zeroIndex, zeroIndex, zeroIndex}; SmallVector secondStartIndices = {oneIndex, zeroIndex, zeroIndex}; - RankedTensorType valType = boundVal.getType().cast(); + RankedTensorType valType = mlir::cast(boundVal.getType()); if (direction == FORWARD || direction == BIDIRECTIONAL) { if (!isNoneValue(initialH)) { forwardHt = create.stablehlo.dynamic_slice( @@ -129,16 +129,16 @@ void stateToOutputForHiddenOrCell(ConversionPatternRewriter &rewriter, output = val; } else { // BIDIRECTIONAL SmallVector bForwardValShape( - forwardVal.getType().cast().getShape()); + mlir::cast(forwardVal.getType()).getShape()); SmallVector bValShape( - forwardVal.getType().cast().getShape()); + mlir::cast(forwardVal.getType()).getShape()); SmallVector bReverseValShape( - reverseVal.getType().cast().getShape()); + mlir::cast(reverseVal.getType()).getShape()); bForwardValShape.insert(bForwardValShape.begin(), 1); bReverseValShape.insert(bReverseValShape.begin(), 1); bValShape.insert(bValShape.begin(), 2); Type valElementType = - forwardVal.getType().cast().getElementType(); + mlir::cast(forwardVal.getType()).getElementType(); Value zero = create.onnx.constantInt64({0}); Value bForwardVal = create.onnx.unsqueeze( RankedTensorType::get(bForwardValShape, valElementType), forwardVal, @@ -158,7 +158,7 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, MultiDialectBuilder create(rewriter, loc); int64_t batchSize = dimAt(X, 1); int64_t inputSize = dimAt(X, 2); - Type elementType = X.getType().cast().getElementType(); + Type elementType = mlir::cast(X.getType()).getElementType(); RankedTensorType squeezedXType = RankedTensorType::get({batchSize, inputSize}, elementType); SmallVector sliceSizes = {1, batchSize, inputSize}; diff --git a/src/Conversion/ONNXToStablehlo/Tensor/ArgMax.cpp b/src/Conversion/ONNXToStablehlo/Tensor/ArgMax.cpp index 5bed24a421..55dca2253a 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/ArgMax.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/ArgMax.cpp @@ -75,7 +75,7 @@ struct ONNXArgMaxOpLoweringToStablehlo : public ConversionPattern { Type outputType = *op->result_type_begin(); assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); - ShapedType outputShapedType = outputType.cast(); + ShapedType outputShapedType = mlir::cast(outputType); Type indexElementType = outputShapedType.getElementType(); Value indexInitValue = rewriter.create( loc, rewriter.getZeroAttr(indexElementType)); @@ -84,7 +84,7 @@ struct ONNXArgMaxOpLoweringToStablehlo : public ConversionPattern { Value data = operandAdaptor.getData(); assert(isRankedShapedType(data.getType()) && "data must be ranked Shaped Type"); - ShapedType dataType = data.getType().cast(); + ShapedType dataType = mlir::cast(data.getType()); Type elementType = dataType.getElementType(); int64_t dataRank = dataType.getRank(); @@ -96,10 +96,11 @@ struct ONNXArgMaxOpLoweringToStablehlo : public ConversionPattern { int64_t keepdims = argMaxOp.getKeepdims(); bool isKeepdims = (keepdims == 1) ? true : false; - Value initValue = rewriter.create(loc, - rewriter.getFloatAttr(elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), - /*isNegative=*/true))); + Value initValue = rewriter.create( + loc, rewriter.getFloatAttr(elementType, + APFloat::getInf( + mlir::cast(elementType).getFloatSemantics(), + /*isNegative=*/true))); RankedTensorType indexType = RankedTensorType::get(dataType.getShape(), indexElementType); diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Constant.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Constant.cpp index 226cfa1433..b8269ad5f8 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Constant.cpp @@ -34,8 +34,9 @@ struct ONNXConstantOpLoweringToStablehlo : public ConversionPattern { return constantOp.emitWarning("Only support dense values at this time"); assert(constantOp.getValue().has_value() && "Value is not set"); auto attr = constantOp.getValue().value(); - Value result = rewriter.create(loc, - ElementsAttrBuilder::toDenseElementsAttr(attr.cast())); + Value result = rewriter.create( + loc, ElementsAttrBuilder::toDenseElementsAttr( + mlir::cast(attr))); rewriter.replaceOp(op, result); return success(); } diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp index 2e40c2ade6..4557959592 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp @@ -33,10 +33,10 @@ struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { // Check that axisLit is a valid dimension index Value tensorArg = operands[0]; - assert(tensorArg.getType().isa() && + assert(mlir::isa(tensorArg.getType()) && "Expected ranked tensor type"); - int64_t rank = tensorArg.getType().cast().getRank(); + int64_t rank = mlir::cast(tensorArg.getType()).getRank(); assert((axisLit >= 0 && axisLit < rank) && "Axis must be in the range [0, input tensor rank - 1]"); @@ -45,7 +45,7 @@ struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { Value dimValue = rewriter.create(loc, inputShape, axisLit); Type dimType = dimOp.getDim().getType(); - Type indexValueType = dimType.cast().getElementType(); + Type indexValueType = mlir::cast(dimType).getElementType(); Value castedIndex = rewriter.create(loc, indexValueType, dimValue); Value indexTensor = rewriter.create( diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp index 95e333c9f1..2c928cc14a 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Expand.cpp @@ -47,12 +47,12 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern { Type outputType = *op->result_type_begin(); assert(isRankedShapedType(inputType) && "Expected Ranked ShapedType"); assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); - ShapedType outputShapedType = outputType.cast(); + ShapedType outputShapedType = mlir::cast(outputType); Type elementType = outputShapedType.getElementType(); int64_t outputRank = outputShapedType.getRank(); Value ones; - if (elementType.isa()) + if (mlir::isa(elementType)) ones = rewriter.create( loc, rewriter.getIntegerAttr(elementType, 1)); else @@ -69,7 +69,7 @@ struct ONNXExpandOpLoweringToStablehlo : public ConversionPattern { broadcastedOnes = rewriter.create( loc, broadcastedType, ones, rewriter.getDenseI64ArrayAttr({})); } else { - ShapedType shapeType = shape.getType().cast(); + ShapedType shapeType = mlir::cast(shape.getType()); assert(shapeType.getRank() == 1 && shapeType.hasStaticShape() && "expected 1D statically shaped shape tensor"); int64_t shapeRank = shapeType.getShape()[0]; diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Flatten.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Flatten.cpp index bf98ee332f..157b44f5c4 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Flatten.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Flatten.cpp @@ -35,7 +35,7 @@ struct ONNXFlattenOpLoweringToStablehlo : public ConversionPattern { Value input = operandAdaptor.getInput(); assert(isRankedShapedType(input.getType()) && "Expected Ranked ShapedType"); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); int64_t rank = inputType.getRank(); int64_t axis = flattenOp.getAxis(); assert(axis >= -rank && axis <= rank - 1); diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp index 066adc3cab..fed53e65dd 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp @@ -48,9 +48,9 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern { Value indices = operandAdaptor.getIndices(); int64_t axisLit = gatherOp.getAxis(); - ShapedType inputType = data.getType().cast(); + ShapedType inputType = mlir::cast(data.getType()); int64_t dataRank = inputType.getRank(); - ShapedType indicesType = indices.getType().cast(); + ShapedType indicesType = mlir::cast(indices.getType()); // Negative value means counting dimensions from the back. axisLit = axisLit < 0 ? axisLit + dataRank : axisLit; diff --git a/src/Conversion/ONNXToStablehlo/Tensor/GatherElements.cpp b/src/Conversion/ONNXToStablehlo/Tensor/GatherElements.cpp index 4140a445d3..4fa3bdb68c 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/GatherElements.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/GatherElements.cpp @@ -46,9 +46,9 @@ struct ONNXGatherElementsOpLoweringToStablehlo : public ConversionPattern { Value indices = operandAdaptor.getIndices(); int64_t axisLit = gatherOp.getAxis(); - ShapedType inputType = data.getType().cast(); + ShapedType inputType = mlir::cast(data.getType()); int64_t rank = inputType.getRank(); // indices has the same rank - ShapedType indicesType = indices.getType().cast(); + ShapedType indicesType = mlir::cast(indices.getType()); Type indexElemType = indicesType.getElementType(); // Negative value means counting dimensions from the back. axisLit = axisLit < 0 ? axisLit + rank : axisLit; diff --git a/src/Conversion/ONNXToStablehlo/Tensor/OneHot.cpp b/src/Conversion/ONNXToStablehlo/Tensor/OneHot.cpp index 3586147a41..1448749c4f 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/OneHot.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/OneHot.cpp @@ -47,7 +47,7 @@ struct ONNXOneHotOpLoweringToStablehlo int64_t axis = shapeHelper.axis; RankedTensorType indicesType = - indices.getType().dyn_cast(); + mlir::dyn_cast(indices.getType()); if (!indicesType || !indicesType.hasStaticShape()) return failure(); ArrayRef indicesShape = indicesType.getShape(); @@ -80,7 +80,8 @@ struct ONNXOneHotOpLoweringToStablehlo Value broadcastZero = rewriter.create( loc, indexType, zero, rewriter.getDenseI64ArrayAttr({})); Value broadcastDepth; - int64_t depthRank = depthValue.getType().cast().getRank(); + int64_t depthRank = + mlir::cast(depthValue.getType()).getRank(); if (depthRank == 1) broadcastDepth = rewriter.create( loc, indexType, depthValue, rewriter.getDenseI64ArrayAttr({0})); @@ -95,7 +96,7 @@ struct ONNXOneHotOpLoweringToStablehlo loc, indexType, compareGeZero, broadcastIndices, positiveIndices); Value compare = rewriter.create( loc, normalizedIndices, iota, stablehlo::ComparisonDirection::EQ); - Type valueType = values.getType().cast().getElementType(); + Type valueType = mlir::cast(values.getType()).getElementType(); Value offValue = rewriter.create(loc, RankedTensorType::get({1}, valueType), values, DenseI64ArrayAttr::get(context, ArrayRef{0}), diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Pad.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Pad.cpp index 14b72a7f92..72f9bff8f8 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Pad.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Pad.cpp @@ -40,7 +40,7 @@ struct ONNXPadOpLoweringToStablehlo : public ConversionPattern { if (!padMode.equals_insensitive("constant")) return failure(); assert(isRankedShapedType(data.getType()) && "Expected Ranked ShapedType"); - ShapedType inputType = data.getType().cast(); + ShapedType inputType = mlir::cast(data.getType()); Type elemType = inputType.getElementType(); int64_t rank = inputType.getRank(); @@ -52,7 +52,7 @@ struct ONNXPadOpLoweringToStablehlo : public ConversionPattern { rewriter.getZeroAttr(elemType))); } else { // constantValue might be 1D tensor, reshape it to scalar - ShapedType constantType = constantValue.getType().cast(); + ShapedType constantType = mlir::cast(constantValue.getType()); if (constantType.getRank() != 0) constantValue = rewriter.create( loc, RankedTensorType::get({}, elemType), constantValue); diff --git a/src/Conversion/ONNXToStablehlo/Tensor/ScatterND.cpp b/src/Conversion/ONNXToStablehlo/Tensor/ScatterND.cpp index e69a1b1350..59415cb779 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/ScatterND.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/ScatterND.cpp @@ -36,8 +36,8 @@ struct ONNXScatterNDOpLoweringToStablehlo Value data = adaptor.getData(); Value updates = adaptor.getUpdates(); Value indices = adaptor.getIndices(); - auto dataType = data.getType().cast(); - auto indicesType = indices.getType().cast(); + auto dataType = mlir::cast(data.getType()); + auto indicesType = mlir::cast(indices.getType()); int64_t dataRank = dataType.getRank(); int64_t indicesRank = indicesType.getRank(); if (indicesType.isDynamicDim(indicesRank - 1)) @@ -50,7 +50,7 @@ struct ONNXScatterNDOpLoweringToStablehlo Type outputType = *op->result_type_begin(); assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); - ShapedType outputShapedType = outputType.cast(); + ShapedType outputShapedType = mlir::cast(outputType); int64_t outputRank = outputShapedType.getRank(); assert(outputRank == dataRank && "Output rank not equal to data rank"); auto scatter_dimension_numbers = diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Shape.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Shape.cpp index 7e0955d8f2..a86582bc10 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Shape.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Shape.cpp @@ -37,8 +37,8 @@ struct ONNXShapeOpLoweringToStablehlo : public ConversionPattern { shapeHelper.computeShapeAndAssertOnFailure(); Type outputType = *op->result_type_begin(); - assert(outputType.isa() && "Expected ShapedType"); - ShapedType outputShapedType = outputType.cast(); + assert(mlir::isa(outputType) && "Expected ShapedType"); + ShapedType outputShapedType = mlir::cast(outputType); Type elementType = outputShapedType.getElementType(); Type resultOutputType = RankedTensorType::get( shapeHelper.getOutputDims(0)[0].getLiteral(), elementType); diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp index 56f880e325..6da4e17f78 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp @@ -44,7 +44,7 @@ struct ONNXSliceOpLoweringToStablehlo : public ConversionPattern { assert(isRankedShapedType(data.getType()) && "data must be ranked Shaped Type"); - ShapedType dataType = data.getType().cast(); + ShapedType dataType = mlir::cast(data.getType()); int64_t rank = dataType.getRank(); Type indexElementType = rewriter.getI64Type(); Value zero = rewriter.create(loc, @@ -59,7 +59,7 @@ struct ONNXSliceOpLoweringToStablehlo : public ConversionPattern { SmallVector axesIntLitToIdx(rank, -1); SmallVector indices; - if (axes.getType().isa()) { + if (mlir::isa(axes.getType())) { // If `axes` are omitted, they are set to `[0, ..., nDim-1]`." for (int64_t i = 0; i < rank; ++i) axesIntLitToIdx[i] = i; @@ -67,7 +67,7 @@ struct ONNXSliceOpLoweringToStablehlo : public ConversionPattern { // If `axes` are constants, read them." int64_t idx = 0; for (IntegerAttr value : valueAttribute.getValues()) { - int64_t axis = value.cast().getInt(); + int64_t axis = mlir::cast(value).getInt(); if (axis < 0) axis += rank; assert((axis >= 0 && axis < (int64_t)rank) && diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Split.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Split.cpp index 8eab7eca59..e715f1750e 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Split.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Split.cpp @@ -36,7 +36,7 @@ struct ONNXSplitOpLoweringToStablehlo : public ConversionPattern { Value split = splitOp.getSplit(); assert(isRankedShapedType(input.getType()) && "data must be ranked Shaped Type"); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); MLIRContext *context = op->getContext(); Location loc = op->getLoc(); uint64_t rank = inputType.getRank(); @@ -54,10 +54,10 @@ struct ONNXSplitOpLoweringToStablehlo : public ConversionPattern { SmallVector splitSizes; if (auto splitAttr = getElementAttributeFromONNXValue(split)) { for (IntegerAttr value : splitAttr.getValues()) { - int64_t splitSize = value.cast().getInt(); + int64_t splitSize = mlir::cast(value).getInt(); splitSizes.push_back(splitSize); } - } else if (split.getType().template isa()) { + } else if (mlir::isa(split.getType())) { assert(!ShapedType::isDynamic(inputDimSize) && "input dim size can't be dynamic"); int64_t sliceSize = inputDimSize / outputNum; diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Squeeze.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Squeeze.cpp index 5cc64cd40e..6415d088a1 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Squeeze.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Squeeze.cpp @@ -37,7 +37,7 @@ struct ONNXSqueezeOpLoweringToStablehlo : public ConversionPattern { Value axes = squeezeOp.getAxes(); assert(isRankedShapedType(data.getType()) && "data must be ranked Shaped Type"); - ShapedType dataType = data.getType().cast(); + ShapedType dataType = mlir::cast(data.getType()); int64_t rank = dataType.getRank(); // Shape helper is unused @@ -48,7 +48,7 @@ struct ONNXSqueezeOpLoweringToStablehlo : public ConversionPattern { SmallVector axesList; if (ElementsAttr axesAttr = getElementAttributeFromONNXValue(axes)) { for (IntegerAttr value : axesAttr.getValues()) { - int64_t axis = value.cast().getInt(); + int64_t axis = mlir::cast(value).getInt(); if (axis < 0) axis += rank; axesList.push_back(axis); diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Tile.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Tile.cpp index b5a7fe8e42..f0694c13ed 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Tile.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Tile.cpp @@ -45,7 +45,7 @@ struct ONNXTileOpLoweringToStablehlo : public ConversionPattern { Value input = tileOp.getInput(); Value multiples = tileOp.getRepeats(); assert(isRankedShapedType(input.getType()) && "Expected Ranked ShapedType"); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); int64_t inputRank = inputType.getRank(); SmallVector inputShapeValues; @@ -68,7 +68,7 @@ struct ONNXTileOpLoweringToStablehlo : public ConversionPattern { } RankedTensorType multiplesType = - multiples.getType().dyn_cast(); + mlir::dyn_cast(multiples.getType()); Type multiplesElementType = multiplesType.getElementType(); int64_t multiplesRank = multiplesType.getRank(); if (multiplesRank != 1) diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Transpose.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Transpose.cpp index 350f8f5de1..18e9ded10c 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Transpose.cpp @@ -39,8 +39,8 @@ struct ONNXTransposeOpLoweringToStablehlo : public ConversionPattern { // Convert the output type Type outputType = *op->result_type_begin(); - assert(outputType.isa() && "Expected ShapedType"); - ShapedType outputShapedType = outputType.cast(); + assert(mlir::isa(outputType) && "Expected ShapedType"); + ShapedType outputShapedType = mlir::cast(outputType); int64_t rank = outputShapedType.getShape().size(); // Attributes diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Unsqueeze.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Unsqueeze.cpp index e517b89d37..92bbd6f734 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Unsqueeze.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Unsqueeze.cpp @@ -37,7 +37,7 @@ struct ONNXUnsqueezeOpLoweringToStablehlo : public ConversionPattern { Value axes = unsqueezeOp.getAxes(); assert(isRankedShapedType(data.getType()) && "data must be ranked Shaped Type"); - ShapedType dataType = data.getType().cast(); + ShapedType dataType = mlir::cast(data.getType()); int64_t rank = dataType.getRank(); // Unused; for example, axles can be read from it. @@ -49,7 +49,7 @@ struct ONNXUnsqueezeOpLoweringToStablehlo : public ConversionPattern { SmallVector axesList; if (ElementsAttr axesAttr = getElementAttributeFromONNXValue(axes)) { for (IntegerAttr value : axesAttr.getValues()) { - int64_t axis = value.cast().getInt(); + int64_t axis = mlir::cast(value).getInt(); if (axis < 0) axis += rank; axesList.push_back(axis); diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 5fe7ff0bdf..2f0a357249 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -73,7 +73,7 @@ void FrontendToTosaLoweringPass::runOnOperation() { // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> std::optional { - if (isTOSASignedInt(type) || isTOSAFloat(type) || type.isa()) + if (isTOSASignedInt(type) || isTOSAFloat(type) || mlir::isa(type)) return type; return std::nullopt; }); diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp index 5499a50535..004850b94c 100644 --- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp +++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp @@ -56,9 +56,9 @@ Value TosaBuilder::createConst( } bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) { - int64_t firstRank = valueRange[0].getType().cast().getRank(); + int64_t firstRank = mlir::cast(valueRange[0].getType()).getRank(); for (Value operand : valueRange) { - auto operandType = operand.getType().cast(); + auto operandType = mlir::cast(operand.getType()); if (firstRank != operandType.getRank()) return true; } @@ -66,7 +66,7 @@ bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) { } Value TosaBuilder::expandRank(Value input, int64_t rank) { - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); int64_t inputRank = inputType.getRank(); assert(inputRank <= rank && "cannot reduce rank of operation"); if (inputRank == rank) @@ -82,13 +82,13 @@ llvm::SmallVector TosaBuilder::equalizeRanks(ValueRange valueRange) { // Get highest rank from the operands. int64_t maxRank = 0; for (auto type : valueRange.getTypes()) { - int64_t currentRank = type.cast().getRank(); + int64_t currentRank = mlir::cast(type).getRank(); maxRank = std::max(maxRank, currentRank); } llvm::SmallVector reshapedValues; // Iterate through all values comparing the rank. for (auto value : valueRange) { - auto shapedType = value.getType().cast(); + auto shapedType = mlir::cast(value.getType()); int64_t currentRank = shapedType.getRank(); // Only add a reshape op if necessary. if (maxRank > currentRank) { @@ -128,12 +128,12 @@ Value TosaBuilder::getSplattedConst(float val, llvm::ArrayRef shape) { } Value TosaBuilder::transpose(mlir::Value &value, llvm::ArrayRef perm) { - int64_t valueRank = value.getType().cast().getRank(); + int64_t valueRank = mlir::cast(value.getType()).getRank(); assert((valueRank == (int64_t)perm.size()) && "value and perm vector don't have the same rank"); // Create Permutation Const Value permList = this->getConst(perm, {valueRank}); - auto valueType = value.getType().cast(); + auto valueType = mlir::cast(value.getType()); // get new value type Type newValueType = RankedTensorType::get( llvm::SmallVector( @@ -153,14 +153,14 @@ Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef size, tosa::CreateOpAndInfer(rewriter(), loc(), RankedTensorType::get( llvm::SmallVector(size.size(), ShapedType::kDynamic), - inputConst.getType().cast().getElementType()), + mlir::cast(inputConst.getType()).getElementType()), inputConst, startAttr, sizeAttr); return newSliceInput; } Value TosaBuilder::reshape(mlir::Value &value, llvm::ArrayRef shape) { auto shapeAttr = rewriter().getDenseI64ArrayAttr(shape); - auto valueType = value.getType().cast(); + auto valueType = mlir::cast(value.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(shape.size(), ShapedType::kDynamic), valueType.getElementType()); @@ -174,7 +174,7 @@ Value TosaBuilder::mul(mlir::Value &lhs, mlir::Value &rhs, int32_t shift) { lhs = valueVec[0]; rhs = valueVec[1]; } - auto lhsType = lhs.getType().cast(); + auto lhsType = mlir::cast(lhs.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), lhsType.getElementType()); @@ -183,8 +183,8 @@ Value TosaBuilder::mul(mlir::Value &lhs, mlir::Value &rhs, int32_t shift) { } Value TosaBuilder::intdiv(mlir::Value &lhs, mlir::Value &rhs) { - Type lhsElementType = lhs.getType().cast().getElementType(); - Type rhsElementType = rhs.getType().cast().getElementType(); + Type lhsElementType = mlir::cast(lhs.getType()).getElementType(); + Type rhsElementType = mlir::cast(rhs.getType()).getElementType(); assert((lhsElementType.isSignlessInteger(32) && rhsElementType.isSignlessInteger(32)) && "Tosa IntDivOp needs 32-bit signless integer inputs"); @@ -195,7 +195,7 @@ Value TosaBuilder::intdiv(mlir::Value &lhs, mlir::Value &rhs) { rhs = valueVec[1]; } - auto lhsType = lhs.getType().cast(); + auto lhsType = mlir::cast(lhs.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), lhsElementType); @@ -204,7 +204,7 @@ Value TosaBuilder::intdiv(mlir::Value &lhs, mlir::Value &rhs) { } Value TosaBuilder::reciprocal(mlir::Value &input) { - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(inputType.getRank(), ShapedType::kDynamic), inputType.getElementType()); @@ -219,7 +219,7 @@ Value TosaBuilder::binaryOp(mlir::Value &lhs, mlir::Value &rhs) { lhs = valueVec[0]; rhs = valueVec[1]; } - auto lhsType = lhs.getType().cast(); + auto lhsType = mlir::cast(lhs.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), lhsType.getElementType()); @@ -246,10 +246,10 @@ ElementsAttr IndexExprBuilderForTosa::getConst(Value value) { } if (auto constOp = dyn_cast_or_null(definingOp)) { if (constOp.getValueAttr()) - return constOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constOp.getValueAttr()); } else if (auto constOp = dyn_cast_or_null(definingOp)) { if (constOp.getValue().has_value()) - return constOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constOp.getValueAttr()); } return nullptr; } diff --git a/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp b/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp index 508d71d159..dc35d39099 100644 --- a/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp @@ -45,7 +45,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op, // Set up constants outside of loop const int64_t sizeOfSliceInput = weightShape[1]; const int64_t sizeOfSliceKernel = weightShape[0] / groups; - auto newInputShape = newInput.getType().cast().getShape(); + auto newInputShape = mlir::cast(newInput.getType()).getShape(); llvm::SmallVector inputSize = { newInputShape[0], newInputShape[1], newInputShape[2], sizeOfSliceInput}; @@ -69,7 +69,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op, // Create conv Type newConvOutputType = RankedTensorType::get( llvm::SmallVector(4, ShapedType::kDynamic), - resultType.cast().getElementType()); + mlir::cast(resultType).getElementType()); Value tempConv2D = tosa::CreateOpAndInfer(rewriter, op->getLoc(), newConvOutputType, newSliceInput, newSliceWeight, newSliceBias, pads, strides, dilations); @@ -79,7 +79,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op, // Create concat op Type newConcatOutputType = RankedTensorType::get( llvm::SmallVector(4, ShapedType::kDynamic), - resultType.cast().getElementType()); + mlir::cast(resultType).getElementType()); Value conv2D = tosa::CreateOpAndInfer( rewriter, op->getLoc(), newConcatOutputType, sliceValues, 3); return conv2D; @@ -103,8 +103,8 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { auto weights = adaptor.getW(); auto bias = adaptor.getB(); - auto inputType = input.getType().cast(); - auto weightType = weights.getType().cast(); + auto inputType = mlir::cast(input.getType()); + auto weightType = mlir::cast(weights.getType()); // Get shapehelper for autopad attributes IndexExprBuilderForTosa createTosaIE(rewriter, convOp->getLoc()); @@ -126,7 +126,7 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { // Convert weights [OC,IC,KH,KW] -> [OC,KH,KW,IC] Value newWeight = tosaBuilder.transpose(weights, {0, 2, 3, 1}); - if (bias.getType().isa()) { + if (mlir::isa(bias.getType())) { DenseElementsAttr newBiasAttr = DenseElementsAttr::get( RankedTensorType::get({weightShape[0]}, rewriter.getF32Type()), {0.0F}); @@ -153,7 +153,7 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { if (group == 1) { Type newConvOutputType = RankedTensorType::get( llvm::SmallVector(4, ShapedType::kDynamic), - resultType.cast().getElementType()); + mlir::cast(resultType).getElementType()); conv2D = tosa::CreateOpAndInfer(rewriter, convOp->getLoc(), newConvOutputType, newInput, newWeight, bias, diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 13f13c7152..2e105d2dc5 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -53,12 +53,12 @@ class ONNXBinaryElementwiseOpLoweringToTOSA auto loc = op.getLoc(); Value lhs = adaptor.getA(); - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = mlir::dyn_cast(lhs.getType()); Value rhs = adaptor.getB(); - auto rhsType = rhs.getType().dyn_cast(); + auto rhsType = mlir::dyn_cast(rhs.getType()); - auto resultType = op.getResult().getType().template dyn_cast(); + auto resultType = mlir::dyn_cast(op.getResult().getType()); if (!lhsType || !rhsType || !resultType) { return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); } @@ -137,7 +137,7 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getA(); Value rhs = adaptor.getB(); - auto resultType = op.getResult().getType().template cast(); + auto resultType = mlir::cast(op.getResult().getType()); Type resultElementType = resultType.getElementType(); TosaBuilder tosaBuilder(rewriter, op->getLoc()); diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp index 7f131ad779..556de3f7d7 100644 --- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp @@ -47,17 +47,16 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { int64_t transB = adaptor.getTransB(); FloatAttr alpha = adaptor.getAlphaAttr(); FloatAttr beta = adaptor.getBetaAttr(); - auto AType = A.getType().cast(); - auto BType = B.getType().cast(); + auto AType = mlir::cast(A.getType()); + auto BType = mlir::cast(B.getType()); auto shapeA = AType.getShape(); auto shapeB = BType.getShape(); - auto resultType = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto resultType = mlir::cast( + getTypeConverter()->convertType(op.getResult().getType())); // C is optional, if it's not there, we need to be aware of it for later // computations - bool isCPresent = C.getType().isa(); + bool isCPresent = mlir::isa(C.getType()); // ONNX uses HW matrix as input and output as it runs a matrix // multiplication. TOSA implements it as a batch matrix multiplication, // meaning the input and output are NHW. As such, there is a need to add @@ -139,11 +138,12 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { /// Check if the bias (C) needs broadcasting when we convert GEMM to FC. static bool hasCCorrectShape(TensorType A, TensorType B, Value C) { - if (!C.getType().isa()) + if (!mlir::isa(C.getType())) return false; ArrayRef AShape = A.getShape(); ArrayRef BShape = B.getShape(); - ArrayRef CShape = C.getType().cast().getShape(); + ArrayRef CShape = + mlir::cast(C.getType()).getShape(); // In the case of GemmToFC, transB is set meaning that B shapes will be // interverted so we check B[0]. Also, C is supposed to be of rank 1 so we // only need to check C[0]. @@ -162,14 +162,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { Value B = op.getB(); Value C = op.getC(); - auto AType = A.getType().cast(); - auto BType = B.getType().cast(); + auto AType = mlir::cast(A.getType()); + auto BType = mlir::cast(B.getType()); - bool isCPresent = !C.getType().isa(); + bool isCPresent = !mlir::isa(C.getType()); // If C is present, it can only be of rank 1, if the rank is not 1, return // false. - if (C.getType().isa() && - C.getType().cast().getRank() != 1) + if (mlir::isa(C.getType()) && + mlir::cast(C.getType()).getRank() != 1) return false; // Input tensor must be of rank 2. @@ -196,7 +196,8 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { bool needsBroadcasting = !hasCCorrectShape(AType, BType, C); Value dummyC = C; if (!isCPresent || needsBroadcasting) { - ArrayRef cformat(resultType.cast().getShape()[1]); + ArrayRef cformat( + mlir::cast(resultType).getShape()[1]); std::vector elements = {}; for (int i = 0; i < cformat[0]; ++i) elements.push_back(0.0F); diff --git a/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp b/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp index b10e9cb852..edaea512bb 100644 --- a/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp +++ b/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp @@ -41,11 +41,11 @@ class ONNXReduceMeanLoweringToTOSA auto keepDims = adaptor.getKeepdims(); auto noOpIfAxesEmpty = adaptor.getNoopWithEmptyAxes(); - auto outputType = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto outputType = mlir::cast( + getTypeConverter()->convertType(op.getResult().getType())); - RankedTensorType inputType = input.getType().dyn_cast(); + RankedTensorType inputType = + mlir::dyn_cast(input.getType()); if (!inputType) return rewriter.notifyMatchFailure(op, "input type not a ranked tensor."); @@ -53,7 +53,8 @@ class ONNXReduceMeanLoweringToTOSA llvm::SmallVector axesVec; if (isNoneValue(axesValue) && !noOpIfAxesEmpty) { // if not present all axes are reduced - const int64_t numberOfAxes = input.getType().cast().getRank(); + const int64_t numberOfAxes = + mlir::cast(input.getType()).getRank(); llvm::SmallVector allDims(numberOfAxes); std::iota(std::begin(allDims), std::end(allDims), 0); axesVec.append(allDims); diff --git a/src/Conversion/ONNXToTOSA/Math/Softmax.cpp b/src/Conversion/ONNXToTOSA/Math/Softmax.cpp index fb7dff835e..81321a5754 100644 --- a/src/Conversion/ONNXToTOSA/Math/Softmax.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Softmax.cpp @@ -73,9 +73,9 @@ class ONNXSoftmaxLoweringToTOSA : public OpConversionPattern { Value input = adaptor.getInput(); // softmax = exp(logits) / reduce_sum(exp(logits), -1) auto outputType = - op.getResult().getType().template dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); auto inputType = - adaptor.getInput().getType().template dyn_cast(); + mlir::dyn_cast(adaptor.getInput().getType()); // Not a ranked tensor input/output if (!outputType || !inputType) { diff --git a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp index 142d60c608..9874961036 100644 --- a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp +++ b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp @@ -48,7 +48,7 @@ void handleIncludePadAttr( rewriter, intValues, loc, {0, 0, 0, 0}, {}); auto constTosaTensor = tosaBuilder.getSplattedConst(0.0); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto padOp = tosa::CreateOpAndInfer(rewriter, loc, mlir::RankedTensorType::get( llvm::SmallVector( diff --git a/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp b/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp index f46a756741..c530c1e152 100644 --- a/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp +++ b/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp @@ -44,7 +44,7 @@ class ONNXMaxPoolSingleOutOpLoweringToTOSA : public ConversionPattern { IntegerAttr storageOrder = adaptor.getStorageOrderAttr(); ArrayAttr dilations = adaptor.getDilationsAttr(); - if (input.getType().isa()) { + if (mlir::isa(input.getType())) { return rewriter.notifyMatchFailure( op, "memrefs as inputs are unsupported by TOSA"); } diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index ace82bce0f..88f47497a7 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -77,14 +77,15 @@ mlir::FailureOr convertPoolOp( //===----------------------------------------------------------------------===// inline bool isTOSASignedInt(mlir::Type type) { - mlir::IntegerType intType = type.dyn_cast(); + mlir::IntegerType intType = mlir::dyn_cast(type); std::set intWidth{1, 8, 16, 32, 48, 64}; return intType && intType.isSignless() && (intWidth.find(intType.getWidth()) != intWidth.end()); } inline bool isTOSAFloat(mlir::Type type) { - return type.isa(); + return mlir::isa( + type); } //===----------------------------------------------------------------------===// diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc index 24983ee811..99558a5bd0 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc @@ -23,7 +23,7 @@ std::optional convertReduceOpCommon( mlir::ElementsAttr axesElems, bool keepDims, mlir::Type reduceElementType) { TosaBuilder tosaBuilder(rewriter, op->getLoc()); mlir::RankedTensorType inputType = - inputValue.getType().dyn_cast(); + mlir::dyn_cast(inputValue.getType()); if (!inputType) return std::nullopt; @@ -140,7 +140,7 @@ inline mlir::LogicalResult getAvgPool2dAccType(mlir::PatternRewriter &rewriter, // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time // FP16 is supported, the accumulator type can be selected based on trade-off // between performance and accuracy. Set to FP32 by default. - accType = inputETy.isa() + accType = mlir::isa(inputETy) ? mlir::TypeAttr::get(rewriter.getF32Type()) : mlir::TypeAttr::get(rewriter.getIntegerType(32)); @@ -162,7 +162,7 @@ mlir::FailureOr convertPoolOp( TosaBuilder tosaBuilder(rewriter, loc); mlir::Value input = adaptor.getX(); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); if (inputType.getShape().size() != 4) { (void)rewriter.notifyMatchFailure(op, "TOSA only supports 2d pooling"); return mlir::failure(); diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp index 24e1e73024..56eb050491 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp @@ -37,14 +37,14 @@ mlir::RankedTensorType reduceAxisToOne(llvm::ArrayRef shape, // Returns the value TOSA ConstOp template T getValueFromTosaConst(mlir::Value &val) { - return val.getDefiningOp().getValue().cast(); + return mlir::cast(val.getDefiningOp().getValue()); } // Creates a TOSA operation and performs shape inference on the individual // op. This allows shape inference during the framework to TOSA lowering. template TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, - mlir::Type result_ty, Args &&... args) { + mlir::Type result_ty, Args &&...args) { auto op = rewriter.create(loc, result_ty, args...); mlir::InferShapedTypeOpInterface shapeInterface = @@ -67,13 +67,13 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, auto predictedShape = returnedShapes[0]; if (predictedShape.hasRank()) updateType(nullptr, op, predictedShape.getDims(), - result_ty.cast().getElementType()); + mlir::cast(result_ty).getElementType()); return op; } template void CreateReplaceOpAndInfer(mlir::PatternRewriter &rewriter, - mlir::Operation *op, mlir::Type result_ty, Args &&... args) { + mlir::Operation *op, mlir::Type result_ty, Args &&...args) { auto result = CreateOpAndInfer(rewriter, op->getLoc(), result_ty, args...); rewriter.replaceOp(op, result->getResults()); diff --git a/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp b/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp index a13c894acb..051a208898 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp @@ -39,13 +39,13 @@ class ONNXConstOpLoweringToTOSA : public OpConversionPattern { op, "tosa.const does not support sparse value"); } Attribute currentAttr = valueAttr.value(); - if (!currentAttr.isa()) { + if (!mlir::isa(currentAttr)) { return rewriter.notifyMatchFailure( op, "tosa.const does not support non-tensor types"); } Type resultType = getTypeConverter()->convertType(op.getResult().getType()); rewriter.replaceOpWithNewOp( - op, resultType, currentAttr.cast()); + op, resultType, mlir::cast(currentAttr)); return success(); } }; diff --git a/src/Conversion/ONNXToTOSA/Tensor/Reshape.cpp b/src/Conversion/ONNXToTOSA/Tensor/Reshape.cpp index aa44d29060..1adfb841d6 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Reshape.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Reshape.cpp @@ -38,8 +38,8 @@ class ONNXReshapeOpLoweringToTOSA : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "dynamic shapes not supported"); Value data = op.getData(); - Value reshapeOp = - tosaBuilder.reshape(data, outputTy.cast().getShape()); + Value reshapeOp = tosaBuilder.reshape( + data, mlir::cast(outputTy).getShape()); rewriter.replaceOp(op, {reshapeOp}); return success(); } diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp index d3535648e2..3f3269a029 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp @@ -178,10 +178,10 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { TosaBuilder tosaBuilder(rewriter, loc); Value input = adaptor.getX(); - auto inputType = input.getType().dyn_cast(); + auto inputType = mlir::dyn_cast(input.getType()); auto resultType = - resizeOp.getResult().getType().dyn_cast(); + mlir::dyn_cast(resizeOp.getResult().getType()); StringRef mode = adaptor.getMode(); StringRef nearestMode = adaptor.getNearestMode(); @@ -305,7 +305,7 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { Type newOutputType = RankedTensorType::get(llvm::SmallVector( inputType.getRank(), ShapedType::kDynamic), - resultType.cast().getElementType()); + mlir::cast(resultType).getElementType()); Value resize = tosa::CreateOpAndInfer(rewriter, loc, newOutputType, newInput, scale, offset, border, resizeModeAttr); diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index 53475325a6..99198a0182 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -365,11 +365,11 @@ ElementsAttr IndexExprBuilderForKrnl::getConst(mlir::Value value) { auto definingOp = value.getDefiningOp(); if (auto globalOp = dyn_cast_or_null(definingOp)) { if (globalOp.getValue().has_value()) - return globalOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(globalOp.getValueAttr()); } else if (auto globalOp = dyn_cast_or_null(definingOp)) { if (globalOp.getValue().has_value()) - return globalOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(globalOp.getValueAttr()); } return nullptr; } diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index 025cce55f7..1d89b46f1e 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -45,7 +45,7 @@ def Krnl_Dialect : Dialect { ]; } -def StringType : Type()">, "string type">; +def StringType : Type($_self)">, "string type">; // Require regions to have krnl.terminate terminator operation. def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">; @@ -56,9 +56,9 @@ def KrnlCallOp : Op { def KrnlLoadOp : Op().getElementType()">, + "mlir::cast($_self).getElementType()">, MemRefsNormalizable]> { let summary = "A Krnl operation to load data from the memref."; @@ -657,7 +657,7 @@ def KrnlLoadOp : Op:$indices), [{ - auto memrefType = memref.getType().cast(); + auto memrefType = mlir::cast(memref.getType()); $_state.addOperands(memref); $_state.addOperands(indices); $_state.types.push_back(memrefType.getElementType()); @@ -668,7 +668,7 @@ def KrnlLoadOp : Op(); + return mlir::cast(getMemref().getType()); } }]; } @@ -676,7 +676,7 @@ def KrnlLoadOp : Op().getElementType()">, + "mlir::cast($_self).getElementType()">, MemRefsNormalizable]> { let summary = "A Krnl operation to store data to the memref."; let description = [{ @@ -704,7 +704,7 @@ def KrnlStoreOp : Op(); + return mlir::cast(getMemref().getType()); } }]; } @@ -741,7 +741,7 @@ def KrnlGetLinearOffsetIndexOp : Op(); + return mlir::cast(getMemref().getType()); } /// Returns the affine map used to index the memref for this operation. @@ -792,7 +792,7 @@ def KrnlPrefetchOp : Op(); + return mlir::cast(getMemref().getType()); } /// Implements the AffineMapAccessInterface. @@ -1154,10 +1154,10 @@ def KrnlMatMulOp : Op().getElementType()">, + "mlir::cast($_self).getElementType()">, TypesMatchWith<"type of 'padValue' matches element type of 'buffer'", "buffer", "padValue", - "$_self.cast().getElementType()">, + "mlir::cast($_self).getElementType()">, MemRefsNormalizable]> { let summary = "Copy to buffer."; let description = [{ @@ -1279,7 +1279,7 @@ def KrnlMemsetOp : Op, TypesMatchWith<"type of 'value' matches element type of 'dest'", "dest", "value", - "$_self.cast().getElementType()">]> { + "mlir::cast($_self).getElementType()">]> { let summary = "Set buffer to a given value."; let description = [{ Krnl operation that sets a buffer to a given value. diff --git a/src/Dialect/Krnl/KrnlHelper.cpp b/src/Dialect/Krnl/KrnlHelper.cpp index 86951734b0..fc989952e7 100644 --- a/src/Dialect/Krnl/KrnlHelper.cpp +++ b/src/Dialect/Krnl/KrnlHelper.cpp @@ -160,7 +160,7 @@ DenseElementsAttr getDenseElementAttributeFromKrnlValue(Value value) { dyn_cast_or_null(value.getDefiningOp()); if (globalOp) if (globalOp.getValue().has_value()) - return globalOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(globalOp.getValueAttr()); return nullptr; } diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp index 17f7693d92..6b32d772ee 100644 --- a/src/Dialect/Krnl/KrnlOps.cpp +++ b/src/Dialect/Krnl/KrnlOps.cpp @@ -146,7 +146,7 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, // Create funcName std::string name = op->getName().getStringRef().str(); std::replace(name.begin(), name.end(), '.', '_'); - ShapedType resultType = resultVals[0].getType().cast(); + ShapedType resultType = mlir::cast(resultVals[0].getType()); Type elementType = resultType.getElementType(); std::string funcNameStr = name + "_" + typeToString(elementType); @@ -360,10 +360,10 @@ void KrnlIterateOp::print(OpAsmPrinter &printer) { printer.printOperand(var); printer << " = "; krnl::printBound( - (*boundItr++).cast(), operandItr, "max", printer); + mlir::cast(*boundItr++), operandItr, "max", printer); printer << " to "; krnl::printBound( - (*boundItr++).cast(), operandItr, "min", printer); + mlir::cast(*boundItr++), operandItr, "min", printer); delimiter = ", "; } @@ -456,7 +456,7 @@ struct LoopParser { boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer)) return failure(); - if (auto affineMapAttr = boundAttr.dyn_cast()) { + if (auto affineMapAttr = mlir::dyn_cast(boundAttr)) { unsigned currentNumOperands = result.operands.size(); unsigned numDims = 0; if (affine::parseDimAndSymbolList(parser, result.operands, numDims)) @@ -488,7 +488,7 @@ struct LoopParser { return success(); } - if (auto integerAttr = boundAttr.dyn_cast()) { + if (auto integerAttr = mlir::dyn_cast(boundAttr)) { AffineMap map = builder.getConstantAffineMap(integerAttr.getValue().getSExtValue()); boundMaps.emplace_back(AffineMapAttr::get(map)); @@ -700,7 +700,7 @@ void KrnlGetInductionVariableValueOp::build(::mlir::OpBuilder &odsBuilder, // only 1D vectors. void KrnlVectorTypeCastOp::build(OpBuilder &builder, OperationState &state, Value sourceMemRef, int64_t vectorLen) { - MemRefType sourceType = sourceMemRef.getType().cast(); + MemRefType sourceType = mlir::cast(sourceMemRef.getType()); Type elementType = sourceType.getElementType(); auto sourceShape = sourceType.getShape(); int rank = sourceShape.size(); @@ -723,8 +723,8 @@ bool KrnlVectorTypeCastOp::areCastCompatible( return false; Type a = inputs.front(), b = outputs.front(); - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); + auto aT = mlir::dyn_cast(a); + auto bT = mlir::dyn_cast(b); if (!aT || !bT) return false; @@ -749,10 +749,10 @@ bool KrnlVectorTypeCastOp::areCastCompatible( return false; // Source memref can't have vector element type. - if (auto shapedEltType = aT.getElementType().dyn_cast()) + if (auto shapedEltType = mlir::dyn_cast(aT.getElementType())) return false; - auto shapedEltTypeB = bT.getElementType().dyn_cast(); + auto shapedEltTypeB = mlir::dyn_cast(bT.getElementType()); if (!shapedEltTypeB) return false; @@ -782,7 +782,7 @@ static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); - if (cast && !cast.getOperand().getType().isa()) { + if (cast && !mlir::isa(cast.getOperand().getType())) { operand.set(cast.getOperand()); folded = true; } @@ -847,11 +847,11 @@ void KrnlMatMulOp::build(::mlir::OpBuilder &odsBuilder, LogicalResult KrnlMatMulOp::verify() { KrnlMatMulOpAdaptor operandAdaptor = KrnlMatMulOpAdaptor(*this); uint64_t aRank = - operandAdaptor.getA().getType().cast().getShape().size(); + mlir::cast(operandAdaptor.getA().getType()).getShape().size(); uint64_t bRank = - operandAdaptor.getB().getType().cast().getShape().size(); + mlir::cast(operandAdaptor.getB().getType()).getShape().size(); uint64_t cRank = - operandAdaptor.getC().getType().cast().getShape().size(); + mlir::cast(operandAdaptor.getC().getType()).getShape().size(); if (!(aRank >= 2 && bRank >= 2 && cRank >= 2)) return emitOpError("currently only support ranks >=2"); if (operandAdaptor.getAGlobalIndexMemStart().size() != aRank) @@ -977,7 +977,7 @@ LogicalResult KrnlCopyFromBufferOp::verify() { IndexExprBuilderForAnalysis createIE(getLoc()); int64_t bufferRank = createIE.getShapedTypeRank(opAdaptor.getBuffer()); int64_t destRank = - opAdaptor.getDest().getType().cast().getShape().size(); + mlir::cast(opAdaptor.getDest().getType()).getShape().size(); int64_t startRank = opAdaptor.getStarts().size(); if (!createIE.isLiteralShape(opAdaptor.getBuffer())) return emitOpError("buffer expect constant dimensions"); @@ -1192,18 +1192,18 @@ ParseResult KrnlPrefetchOp::parse(OpAsmParser &parser, OperationState &result) { parser.resolveOperands(mapOperands, indexTy, result.operands)) return failure(); - if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + if (!(readOrWrite == "read") && !(readOrWrite == "write")) return parser.emitError( parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); result.addAttribute(KrnlPrefetchOp::getIsWriteAttrStrName(), - parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + parser.getBuilder().getBoolAttr(readOrWrite == "write")); - if (!cacheType.equals("data") && !cacheType.equals("instr")) + if (!(cacheType == "data") && !(cacheType == "instr")) return parser.emitError( parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); result.addAttribute(KrnlPrefetchOp::getIsDataCacheAttrStrName(), - parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + parser.getBuilder().getBoolAttr(cacheType == "data")); return success(); } diff --git a/src/Dialect/Krnl/KrnlTypes.cpp b/src/Dialect/Krnl/KrnlTypes.cpp index 69dd142f9f..f725eb318b 100644 --- a/src/Dialect/Krnl/KrnlTypes.cpp +++ b/src/Dialect/Krnl/KrnlTypes.cpp @@ -23,11 +23,11 @@ namespace krnl { void customizeTypeConverter(LLVMTypeConverter &typeConverter) { typeConverter.addConversion([&](MemRefType type) -> std::optional { Type elementType = type.getElementType(); - if (!elementType.isa()) + if (!mlir::isa(elementType)) return std::nullopt; - elementType = - elementType.cast().getLLVMType(type.getContext()); + elementType = mlir::cast(elementType) + .getLLVMType(type.getContext()); return typeConverter.convertType( MemRefType::get(type.getShape(), elementType)); }); diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 0407aeab48..d36aa837fe 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -52,11 +52,11 @@ namespace onnx_mlir { // MLIR unsigned integer. /* static */ bool MathBuilder::isVector(Type type) { - return type.dyn_cast() != nullptr; + return mlir::dyn_cast(type) != nullptr; } /* static */ Type MathBuilder::elementTypeWithVector(Type elementOrVectorType) { - VectorType vectorType = elementOrVectorType.dyn_cast(); + VectorType vectorType = mlir::dyn_cast(elementOrVectorType); if (vectorType) return vectorType.getElementType(); return elementOrVectorType; @@ -71,7 +71,8 @@ namespace onnx_mlir { /* static */ bool MathBuilder::isIntegerWithVector(Type elementOrVectorType) { Type elementType = elementTypeWithVector(elementOrVectorType); - return elementType.isa() || elementType.isa(); + return mlir::isa(elementType) || + mlir::isa(elementType); } /* static */ bool MathBuilder::isUnsignedIntegerWithVector( @@ -82,7 +83,7 @@ namespace onnx_mlir { /* static */ bool MathBuilder::isFloatWithVector(Type elementOrVectorType) { Type elementType = elementTypeWithVector(elementOrVectorType); - return elementType.isa(); + return mlir::isa(elementType); } Value MathBuilder::abs(Value val) const { @@ -119,7 +120,7 @@ Value MathBuilder::add(Value lhs, Value rhs) const { if (isIntegerWithVector(lhs.getType())) { Type elemType = elementTypeWithVector(lhs.getType()); if (elemType.isUnsignedInteger()) { - unsigned elemWidth = elemType.cast().getWidth(); + unsigned elemWidth = mlir::cast(elemType).getWidth(); Value castLhs = castToSignless(lhs, elemWidth); Value castRhs = castToSignless(rhs, elemWidth); Value castAdd = @@ -147,7 +148,7 @@ Value MathBuilder::mul(Value lhs, Value rhs) const { if (isIntegerWithVector(lhs.getType())) { Type elemType = elementTypeWithVector(lhs.getType()); if (elemType.isUnsignedInteger()) { - unsigned elemWidth = elemType.cast().getWidth(); + unsigned elemWidth = mlir::cast(elemType).getWidth(); Value castLhs = castToSignless(lhs, elemWidth); Value castRhs = castToSignless(rhs, elemWidth); Value castMul = @@ -462,10 +463,10 @@ Value MathBuilder::constant(Type type, double val) const { .Default([](Type) { llvm_unreachable("unsupported element type"); }); assert(constant != nullptr && "Expecting valid constant value"); - if (type.isa()) { + if (mlir::isa(type)) { // For vectors, need to splat the constant. MultiDialectBuilder create(*this); - VectorType vecType = type.dyn_cast(); + VectorType vecType = mlir::dyn_cast(type); constant = create.vec.splat(vecType, constant); } return constant; @@ -572,10 +573,10 @@ Value MathBuilder::negativeInf(Type type) const { TypedAttr attr = negativeInfAttr(elementType); Value constant = b().create(loc(), attr); assert(constant != nullptr && "Expecting valid constant value"); - if (type.isa()) { + if (mlir::isa(type)) { // For vectors, need to splat the constant. MultiDialectBuilder create(*this); - VectorType vecType = type.dyn_cast(); + VectorType vecType = mlir::dyn_cast(type); constant = create.vec.splat(vecType, constant); } return constant; @@ -587,10 +588,10 @@ Value MathBuilder::positiveInf(Type type) const { TypedAttr attr = positiveInfAttr(elementType); Value constant = b().create(loc(), attr); assert(constant != nullptr && "Expecting valid constant value"); - if (type.isa()) { + if (mlir::isa(type)) { // For vectors, need to splat the constant. MultiDialectBuilder create(*this); - VectorType vecType = type.dyn_cast(); + VectorType vecType = mlir::dyn_cast(type); constant = create.vec.splat(vecType, constant); } return constant; @@ -617,10 +618,10 @@ Value MathBuilder::createArithCmp( // best of my understanding. Value MathBuilder::castToSignless(Value val, int64_t width) const { Type valType = val.getType(); - VectorType vecType = valType.dyn_cast(); + VectorType vecType = mlir::dyn_cast(valType); Type valElemType = elementTypeWithVector(valType); - assert(valElemType.isa() && !valElemType.isSignlessInteger() && - "Expecting signed integer type"); + assert(mlir::isa(valElemType) && + !valElemType.isSignlessInteger() && "Expecting signed integer type"); Type destType = getTypeWithVector(vecType, b().getIntegerType(width)); return b() .create(loc(), destType, val) @@ -629,9 +630,9 @@ Value MathBuilder::castToSignless(Value val, int64_t width) const { Value MathBuilder::castToUnsigned(Value val, int64_t width) const { Type valType = val.getType(); - VectorType vecType = valType.dyn_cast(); + VectorType vecType = mlir::dyn_cast(valType); Type valElemType = elementTypeWithVector(valType); - assert(valElemType.isa() && "Expecting integer type"); + assert(mlir::isa(valElemType) && "Expecting integer type"); Type destType = getTypeWithVector(vecType, b().getIntegerType(width, false /*signed*/)); return b() @@ -643,8 +644,8 @@ Value MathBuilder::castToUnsigned(Value val, int64_t width) const { Value MathBuilder::cast(Type destType, Value src) const { // Get element type and vector types (if any, i.e. possibly nullptr). Type srcType = src.getType(); - VectorType srcVecType = srcType.dyn_cast(); - VectorType destVecType = destType.dyn_cast(); + VectorType srcVecType = mlir::dyn_cast(srcType); + VectorType destVecType = mlir::dyn_cast(destType); Type srcElemType = elementTypeWithVector(srcType); Type destElemType = elementTypeWithVector(destType); // Make sure we don't mix vector and scalars. @@ -655,7 +656,7 @@ Value MathBuilder::cast(Type destType, Value src) const { return src; // Process index types first. - if (srcElemType.isa()) { + if (mlir::isa(srcElemType)) { // If the source is an index type, first convert it into a signless int of // size 64. srcElemType = b().getIntegerType(64); @@ -664,7 +665,7 @@ Value MathBuilder::cast(Type destType, Value src) const { } bool destIsIndex = false; Type savedDestType = destType; // Used when destIsIndex is true. - if (destElemType.isa()) { + if (mlir::isa(destElemType)) { // If the dest is an index type, pretend for now that we want it to be // converted to signless int of size 64. destElemType = b().getIntegerType(64); @@ -675,9 +676,11 @@ Value MathBuilder::cast(Type destType, Value src) const { // Only support Integer or Float type at this stage. Index were transformed // to signless int. // TODO: add support for shaped tensor (MemRef, Vector, Tensor?) if needed. - assert((srcElemType.isa() || srcElemType.isa()) && + assert((mlir::isa(srcElemType) || + mlir::isa(srcElemType)) && "support only float or int"); - assert((destElemType.isa() || destElemType.isa()) && + assert((mlir::isa(destElemType) || + mlir::isa(destElemType)) && "support only float or int"); // Get source and dest type width. int64_t srcElemWidth = srcElemType.getIntOrFloatBitWidth(); @@ -691,7 +694,7 @@ Value MathBuilder::cast(Type destType, Value src) const { // Handle boolean first because they need special handling. // Boolean to int/float conversions. Boolean are unsigned. if (srcElemType.isInteger(1)) { - if (destElemType.isa()) { + if (mlir::isa(destElemType)) { return b().create(loc(), destType, src); } else { Value dest = b().create(loc(), destType, src); @@ -704,9 +707,10 @@ Value MathBuilder::cast(Type destType, Value src) const { // Int/Float to booleans, just compare value to be unequal zero. if (destElemType.isInteger(1)) { Type constantType = srcType; - if (srcElemType.isa() && !srcElemType.isSignlessInteger()) { + if (mlir::isa(srcElemType) && + !srcElemType.isSignlessInteger()) { // An integer constant must be signless. - unsigned srcElemWidth = srcElemType.cast().getWidth(); + unsigned srcElemWidth = mlir::cast(srcElemType).getWidth(); constantType = getTypeWithVector( srcVecType, IntegerType::get(srcElemType.getContext(), srcElemWidth)); src = castToSignless(src, srcElemWidth); @@ -716,7 +720,7 @@ Value MathBuilder::cast(Type destType, Value src) const { } // Float to float conversions. - if (srcElemType.isa() && destElemType.isa()) { + if (mlir::isa(srcElemType) && mlir::isa(destElemType)) { assert((bitExtend || bitTrunc) && "expected extend or trunc"); if (bitExtend) return b().create(loc(), destType, src); @@ -725,7 +729,8 @@ Value MathBuilder::cast(Type destType, Value src) const { } // Float to int conversions. - if (srcElemType.isa() && destElemType.isa()) { + if (mlir::isa(srcElemType) && + mlir::isa(destElemType)) { // TosaToLinalg in MLIR uses a fancier algorithm that clamps values to // min/max signed/unsigned integer values. if (destType.isUnsignedInteger()) { @@ -742,7 +747,8 @@ Value MathBuilder::cast(Type destType, Value src) const { } // Int to float conversion. - if (srcElemType.isa() && destElemType.isa()) { + if (mlir::isa(srcElemType) && + mlir::isa(destElemType)) { if (srcElemType.isUnsignedInteger()) { Value cast = castToSignless(src, srcElemWidth); return b().create(loc(), destType, cast); @@ -753,7 +759,7 @@ Value MathBuilder::cast(Type destType, Value src) const { } // Int to int conversion. - if (srcType.isa() && destType.isa()) { + if (mlir::isa(srcType) && mlir::isa(destType)) { if (srcType.isUnsignedInteger()) { // Unsigned to unsigned/signed conversion. // Same bit width for unsigned to signed conversion. @@ -981,7 +987,7 @@ bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, ValueRange dynSymbols, int64_t &staticSize, IndexExpr &dynSize, int64_t range) const { Type elementType = type.getElementType(); - assert(!(elementType.isa()) && "unsupported vector type"); + assert(!(mlir::isa(elementType)) && "unsupported vector type"); ArrayRef shape = type.getShape(); staticSize = 1; // Multiplication of static sizes. dynSize = LiteralIndexExpr(1); // Multiplication of dyn sizes. @@ -1044,7 +1050,7 @@ Value MemRefBuilder::alignedAllocWithSimdPadding(MemRefType type, ValueRange dynSymbols, int64_t VL, int64_t alignment) const { Type elementType = type.getElementType(); assert(!hasNonIdentityLayout(type) && "unsupported layout"); - assert(!(elementType.isa()) && "unsupported vector type"); + assert(!(mlir::isa(elementType)) && "unsupported vector type"); assert(VL >= 1 && "expected positive simd unroll factor"); // Compute total size of memref (in unit of element type). int64_t staticSize; @@ -1169,7 +1175,8 @@ memref::ReshapeOp MemRefBuilder::reshape( create.affine.store(destDims[d].getValue(), outputShapeInMem, {dd}); } // Create output type. - Type elementType = valToReshape.getType().cast().getElementType(); + Type elementType = + mlir::cast(valToReshape.getType()).getElementType(); MemRefType destType = MemRefType::get(outputShape, elementType); // Perform actual reshape operation return reshape(destType, valToReshape, outputShapeInMem); @@ -1184,7 +1191,7 @@ Value MemRefBuilder::reshapeToFlatInnermost(Value valToReshape, llvm::SmallVectorImpl &flattenedDims, int64_t dimsToFlatten) const { // Parse input. - MemRefType inputType = valToReshape.getType().cast(); + MemRefType inputType = mlir::cast(valToReshape.getType()); assert(!hasNonIdentityLayout(inputType) && "MemRef is not normalized"); int64_t inputRank = inputType.getRank(); // Verify dims has the right number of elements. @@ -1216,7 +1223,7 @@ Value MemRefBuilder::reshapeToFlat2D(Value valToReshape, llvm::SmallVectorImpl &dims, llvm::SmallVectorImpl &flattenedDims, int64_t axis) const { // Parse input. - MemRefType inputType = valToReshape.getType().cast(); + MemRefType inputType = mlir::cast(valToReshape.getType()); assert(!hasNonIdentityLayout(inputType) && "MemRef is not normalized"); int64_t inputRank = inputType.getRank(); // Verify dims has the right number of elements. @@ -1284,7 +1291,7 @@ Value MemRefBuilder::reinterpretCast( IndexExpr::getShape(outputDims, outputShape); IndexExpr::getOpOrFoldResults(sizesIE, sizes); IndexExpr::getOpOrFoldResults(stridesIE, strides); - Type elementType = input.getType().cast().getElementType(); + Type elementType = mlir::cast(input.getType()).getElementType(); MemRefType outputMemRefType = MemRefType::get(outputShape, elementType); if (offset) return b().create( @@ -1297,7 +1304,7 @@ Value MemRefBuilder::reinterpretCast( Value MemRefBuilder::collapseShape( Value input, ArrayRef reassociation) { // Extract input info. - MemRefType inputType = input.getType().cast(); + MemRefType inputType = mlir::cast(input.getType()); assert(inputType && "expected input with memref type"); assert(!hasNonIdentityLayout(inputType) && "collapse only for identity layout at this time"); @@ -1367,7 +1374,7 @@ memref::SubViewOp MemRefBuilder::subView(Value input, IndexExpr::getOpOrFoldResults(stridesIE, strides); SmallVector outputShape; IndexExpr::getShape(sizesIE, outputShape); - MemRefType inputType = input.getType().dyn_cast(); + MemRefType inputType = mlir::dyn_cast(input.getType()); MemRefLayoutAttrInterface layout; MemRefType outputType = MemRefType::get(outputShape, inputType.getElementType(), layout, inputType.getMemorySpace()); @@ -1384,8 +1391,8 @@ Value MemRefBuilder::dim(Value val, int64_t index) const { } Value MemRefBuilder::dim(Value val, Value index) const { - // assert((val.getType().isa() || - // val.getType().isa()) && + // assert((mlir::isa(val.getType()) || + // mlir::isa(val.getType())) && // "memref::DimOp expects input operand to have MemRefType or " // "UnrankedMemRefType"); return Value(b().createOrFold(loc(), val, index)); @@ -1485,7 +1492,7 @@ int64_t VectorBuilder::getMachineVectorLength(const VectorType &vecType) const { } int64_t VectorBuilder::getMachineVectorLength(Value vecValue) const { - VectorType vecType = vecValue.getType().dyn_cast_or_null(); + VectorType vecType = mlir::dyn_cast_or_null(vecValue.getType()); assert(vecType && "expected vector type"); return getMachineVectorLength(vecType.getElementType()); } @@ -1558,7 +1565,7 @@ bool VectorBuilder::isPowerOf2(uint64_t num) const { } uint64_t VectorBuilder::getLengthOf1DVector(Value vec) const { - VectorType vecType = vec.getType().dyn_cast_or_null(); + VectorType vecType = mlir::dyn_cast_or_null(vec.getType()); assert(vecType && "expected a vector type"); auto vecShape = vecType.getShape(); assert(vecShape.size() == 1 && "expected a 1D vector"); diff --git a/src/Dialect/Mlir/IndexExprBuilder.cpp b/src/Dialect/Mlir/IndexExprBuilder.cpp index 16491953a8..f19a0dc8c8 100644 --- a/src/Dialect/Mlir/IndexExprBuilder.cpp +++ b/src/Dialect/Mlir/IndexExprBuilder.cpp @@ -84,7 +84,7 @@ namespace onnx_mlir { // a dependence on ONNX. bool IndexExprBuilder::hasShapeAndRank(Value value) { assert(value && "expected a value"); - ShapedType shapedType = value.getType().dyn_cast_or_null(); + ShapedType shapedType = mlir::dyn_cast_or_null(value.getType()); return shapedType && shapedType.hasRank(); } @@ -98,7 +98,7 @@ void IndexExprBuilder::assertHasShapeAndRank(Value value) { uint64_t IndexExprBuilder::getShapedTypeRank(Value value) { assertHasShapeAndRank(value); // Find shaped type size (rank of 0 is scalar). - return value.getType().cast().getRank(); + return mlir::cast(value.getType()).getRank(); } // Size from 1D attribute array. @@ -129,7 +129,7 @@ IndexExpr IndexExprBuilder::getIntFromArrayAsLiteral( uint64_t size = getArraySize(intAttrArray); if (i >= size) return UndefinedIndexExpr(); - int64_t val = (intAttrArray.getValue()[i]).cast().getInt(); + int64_t val = mlir::cast(intAttrArray.getValue()[i]).getInt(); return LiteralIndexExpr(val); } @@ -364,7 +364,7 @@ bool IndexExprBuilder::isLiteralShape(Value tensorOrMemrefValue) { int64_t IndexExprBuilder::getShape(Value tensorOrMemrefValue, uint64_t i) { uint64_t rank = getShapedTypeRank(tensorOrMemrefValue); assert(i < rank && "expected index smaller than memref rank"); - return tensorOrMemrefValue.getType().cast().getShape()[i]; + return mlir::cast(tensorOrMemrefValue.getType()).getShape()[i]; } // Get index expressions from tensor/memref shape. diff --git a/src/Dialect/Mlir/IndexExprDetail.cpp b/src/Dialect/Mlir/IndexExprDetail.cpp index aa65f6771a..4289917e92 100644 --- a/src/Dialect/Mlir/IndexExprDetail.cpp +++ b/src/Dialect/Mlir/IndexExprDetail.cpp @@ -106,14 +106,14 @@ void IndexExprImpl::initAsLiteral(double const val, const IndexExprKind kind) { static bool getIntegerLiteralFromValue(Value value, int64_t &intLit) { // From lib/Dialect/LinAlg/Transform/Promotion.cpp if (auto constantOp = value.getDefiningOp()) { - if (constantOp.getType().isa()) - intLit = constantOp.getValue().cast().getInt(); + if (mlir::isa(constantOp.getType())) + intLit = mlir::cast(constantOp.getValue()).getInt(); return true; } // Since ConstantIndexOp is a subclass of ConstantOp, not sure if this one is // needed. if (auto constantOp = value.getDefiningOp()) { - if (constantOp.getType().isa()) + if (mlir::isa(constantOp.getType())) intLit = constantOp.value(); return true; } @@ -123,14 +123,15 @@ static bool getIntegerLiteralFromValue(Value value, int64_t &intLit) { static bool getFloatLiteralFromValue(Value value, double &floatLit) { // From lib/Dialect/LinAlg/Transform/Promotion.cpp if (auto constantOp = value.getDefiningOp()) { - if (constantOp.getType().isa()) - floatLit = constantOp.getValue().cast().getValueAsDouble(); + if (mlir::isa(constantOp.getType())) + floatLit = + mlir::cast(constantOp.getValue()).getValueAsDouble(); return true; } // Since ConstantFloatOp is a subclass of ConstantOp, not sure if this one is // needed. if (auto constantOp = value.getDefiningOp()) { - if (constantOp.getType().isa()) + if (mlir::isa(constantOp.getType())) floatLit = constantOp.value().convertToDouble(); return true; } @@ -143,7 +144,7 @@ void IndexExprImpl::initAsKind(Value const val, IndexExprKind const newKind) { assert(val != nullptr && "expected a defined value"); // Check that the value is of the right type. auto type = val.getType(); - bool valIsFloat = (type.isa()); + bool valIsFloat = (mlir::isa(type)); // Questionmark if (newKind == IndexExprKind::Questionmark) { initAsQuestionmark(valIsFloat); @@ -170,24 +171,24 @@ void IndexExprImpl::initAsKind(Value const val, IndexExprKind const newKind) { return; } Value newVal = val; - if (type.isa()) { + if (mlir::isa(type)) { if (newKind != IndexExprKind::Predicate) { // We need to convert the int into an index, since we are dealing with // index expressions. newVal = scope->getRewriter().create( scope->getLoc(), scope->getRewriter().getIndexType(), newVal); } - } else if (type.isa()) { + } else if (mlir::isa(type)) { if (newKind == IndexExprKind::Predicate) { // We need to convert the int into an index, since we are dealing with // index expressions. newVal = scope->getRewriter().create( scope->getLoc(), scope->getRewriter().getI1Type(), newVal); } - } else if (type.isa()) { + } else if (mlir::isa(type)) { assert(newKind != IndexExprKind::Predicate && "float cannot be predicate"); // Assume its a single precision float. - unsigned width = type.cast().getWidth(); + unsigned width = mlir::cast(type).getWidth(); assert(width == 32 && "Index expression only support f32 at this time"); } else { llvm_unreachable("unsupported element type"); @@ -227,9 +228,9 @@ void IndexExprImpl::init(bool newIsDefined, bool newIsIntLit, bool isFloatLit, if (value != nullptr) { // We have a value initialized index expr. Determine if we have an integer // or float expression. - if (value.getType().isa()) { + if (mlir::isa(value.getType())) { // Assume its a single precision float. - unsigned width = value.getType().cast().getWidth(); + unsigned width = mlir::cast(value.getType()).getWidth(); assert(width == 32 && "Index expression only support f32 at this time"); isFloat = true; } diff --git a/src/Dialect/Mlir/VectorMachineSupport.cpp b/src/Dialect/Mlir/VectorMachineSupport.cpp index cd0de91111..f25818e9a8 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.cpp +++ b/src/Dialect/Mlir/VectorMachineSupport.cpp @@ -118,7 +118,7 @@ int64_t Z16VectorMachineSupport::getVectorLength( GenericOps Gop, Type elementType) { int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t abstractVL = VectorMachineSupport::getVectorLength(elementType); - bool isFloat = elementType.isa(); + bool isFloat = mlir::isa(elementType); // Support shared between int and float. switch (Gop) { @@ -192,7 +192,7 @@ int64_t SSE42x86VectorMachineSupport::getVectorLength( GenericOps Gop, mlir::Type elementType) { int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t abstractVL = VectorMachineSupport::getVectorLength(elementType); - bool isFloat = elementType.isa(); + bool isFloat = mlir::isa(elementType); // Support shared between int and float. switch (Gop) { @@ -277,7 +277,7 @@ int64_t NeonVectorMachineSupport::getVectorLength( GenericOps Gop, mlir::Type elementType) { int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t abstractVL = VectorMachineSupport::getVectorLength(elementType); - bool isFloat = elementType.isa(); + bool isFloat = mlir::isa(elementType); // Support shared between int and float. switch (Gop) { diff --git a/src/Dialect/ONNX/DialectBuilder.cpp b/src/Dialect/ONNX/DialectBuilder.cpp index 73537ca1c7..36de69f9a5 100644 --- a/src/Dialect/ONNX/DialectBuilder.cpp +++ b/src/Dialect/ONNX/DialectBuilder.cpp @@ -37,17 +37,17 @@ IntegerAttr OnnxBuilder::getSignedInt64Attr(int64_t n) const { // ============================================================================= Value OnnxBuilder::add(Value A, Value B) const { - assert((A.getType().cast().getElementType() == - B.getType().cast().getElementType()) && + assert((mlir::cast(A.getType()).getElementType() == + mlir::cast(B.getType()).getElementType()) && "A and B must have the same element type"); return createOpAndInferShapes(toTensor(A), toTensor(B)); } Value OnnxBuilder::cast(Value input, IntegerAttr saturate, TypeAttr to) const { Type resultType; - if (input.getType().cast().hasRank()) + if (mlir::cast(input.getType()).hasRank()) resultType = RankedTensorType::get( - input.getType().cast().getShape(), to.getValue()); + mlir::cast(input.getType()).getShape(), to.getValue()); else resultType = UnrankedTensorType::get(to.getValue()); return createTypedOpAndInferShapes( @@ -56,9 +56,9 @@ Value OnnxBuilder::cast(Value input, IntegerAttr saturate, TypeAttr to) const { Value OnnxBuilder::cast(Value input, TypeAttr to) const { Type resultType; - if (input.getType().cast().hasRank()) + if (mlir::cast(input.getType()).hasRank()) resultType = RankedTensorType::get( - input.getType().cast().getShape(), to.getValue()); + mlir::cast(input.getType()).getShape(), to.getValue()); else resultType = UnrankedTensorType::get(to.getValue()); IntegerAttr saturate = nullptr; @@ -130,8 +130,8 @@ void OnnxBuilder::dimGroup(Value input, int axis, int groupID) const { } Value OnnxBuilder::div(Value A, Value B) const { - assert((A.getType().cast().getElementType() == - B.getType().cast().getElementType()) && + assert((mlir::cast(A.getType()).getElementType() == + mlir::cast(B.getType()).getElementType()) && "A and B must have the same element type"); return createOpAndInferShapes(toTensor(A), toTensor(B)); } @@ -172,12 +172,12 @@ Value OnnxBuilder::RMSLayerNorm(Type outputType, Value input, Value scale, Value OnnxBuilder::matmul(Type Y, Value A, Value B, bool useGemm) const { // Gemm only supports rank 2. - bool canUseGemm = useGemm && A.getType().isa() && - A.getType().cast().hasRank() && - (A.getType().cast().getRank() == 2) && - B.getType().isa() && - B.getType().cast().hasRank() && - (B.getType().cast().getRank() == 2); + bool canUseGemm = useGemm && mlir::isa(A.getType()) && + mlir::cast(A.getType()).hasRank() && + (mlir::cast(A.getType()).getRank() == 2) && + mlir::isa(B.getType()) && + mlir::cast(B.getType()).hasRank() && + (mlir::cast(B.getType()).getRank() == 2); auto aValue = toTensor(A); auto bValue = toTensor(B); if (canUseGemm) @@ -203,15 +203,15 @@ Value OnnxBuilder::min(ValueRange inputs) const { } Value OnnxBuilder::mul(Value A, Value B) const { - assert((A.getType().cast().getElementType() == - B.getType().cast().getElementType()) && + assert((mlir::cast(A.getType()).getElementType() == + mlir::cast(B.getType()).getElementType()) && "A and B must have the same element type"); return createOpAndInferShapes(toTensor(A), toTensor(B)); } Value OnnxBuilder::mul(Type resultType, Value A, Value B) const { - assert((A.getType().cast().getElementType() == - B.getType().cast().getElementType()) && + assert((mlir::cast(A.getType()).getElementType() == + mlir::cast(B.getType()).getElementType()) && "A and B must have the same element type"); return createTypedOpAndInferShapes( resultType, toTensor(A), toTensor(B)); @@ -223,7 +223,7 @@ Value OnnxBuilder::pad( Value input, Value pads, Value constantValue, std::string mode) const { Type elementType = getElementType(input.getType()); Type outputType = UnrankedTensorType::get(elementType); - Value constant = constantValue.getType().isa() + Value constant = mlir::isa(constantValue.getType()) ? constantValue : toTensor(constantValue); return createTypedOpAndInferShapes(toTensor(outputType), @@ -356,8 +356,8 @@ Value OnnxBuilder::squeeze(Type outputType, Value data, Value axes) const { } Value OnnxBuilder::sub(Value A, Value B) const { - assert((A.getType().cast().getElementType() == - B.getType().cast().getElementType()) && + assert((mlir::cast(A.getType()).getElementType() == + mlir::cast(B.getType()).getElementType()) && "A and B must have the same element type"); return createOpAndInferShapes(toTensor(A), toTensor(B)); } @@ -383,9 +383,9 @@ Value OnnxBuilder::toTensor(Value input) const { // None input. if (isNoneValue(input)) return input; - if (input.getType().isa()) + if (mlir::isa(input.getType())) return input; - assert(input.getType().isa() && + assert(mlir::isa(input.getType()) && "expect RankedMemref type when not a TensorType"); auto aTensorTy = toTensor(input.getType()); // No shape inference for this op. @@ -395,13 +395,13 @@ Value OnnxBuilder::toTensor(Value input) const { } TensorType OnnxBuilder::toTensor(Type input) const { - if (auto tensorType = input.dyn_cast()) + if (auto tensorType = mlir::dyn_cast(input)) return tensorType; - assert(input.isa() && + assert(mlir::isa(input) && "expect RankedMemref type when not a TensorType"); - auto aTy = input.cast(); + auto aTy = mlir::cast(input); Type elementTy = aTy.getElementType(); - if (elementTy.isa()) { + if (mlir::isa(elementTy)) { elementTy = b().getIntegerType(64); } return RankedTensorType::get(aTy.getShape(), elementTy); @@ -409,15 +409,16 @@ TensorType OnnxBuilder::toTensor(Type input) const { TypeRange OnnxBuilder::toTensors(TypeRange inputs) const { assert(inputs.size() >= 2 && "Expect at least two inputs"); - if (llvm::all_of(inputs, [](Type t) { return (t.isa()); })) + if (llvm::all_of(inputs, [](Type t) { return (mlir::isa(t)); })) return inputs; - assert(llvm::all_of(inputs, [](Type t) { return (t.isa()); }) && - "All inputs expect RankedMemref type when not a TensorType"); + assert(llvm::all_of(inputs, [](Type t) { + return (mlir::isa(t)); + }) && "All inputs expect RankedMemref type when not a TensorType"); llvm::SmallVector resultTypes; for (uint64_t i = 0; i < inputs.size(); ++i) { - ShapedType aTy = inputs[i].cast(); + ShapedType aTy = mlir::cast(inputs[i]); Type elementTy = aTy.getElementType(); - if (elementTy.isa()) { + if (mlir::isa(elementTy)) { elementTy = b().getIntegerType(64); } resultTypes.emplace_back(RankedTensorType::get(aTy.getShape(), elementTy)); @@ -426,11 +427,11 @@ TypeRange OnnxBuilder::toTensors(TypeRange inputs) const { } Value OnnxBuilder::toMemref(Value input) const { - if (input.getType().isa()) + if (mlir::isa(input.getType())) return input; - assert(input.getType().isa() && + assert(mlir::isa(input.getType()) && "expect RankedMemref type when not a TensorType"); - auto aTy = input.getType().cast(); + auto aTy = mlir::cast(input.getType()); auto aTensorTy = MemRefType::get(aTy.getShape(), aTy.getElementType()); // No shape inference for this op. return b() @@ -468,14 +469,15 @@ Value OnnxBuilder::where( Value OnnxBuilder::reshapeToNDim( Value val, int64_t N, bool collapseMostSignificant) const { // Get rank of the original shape and determine if we have anything to do. - int64_t rank = val.getType().cast().getRank(); + int64_t rank = mlir::cast(val.getType()).getRank(); int64_t keep = N - 1; // 1 dim for collapsed dims, keep N -1 from original. assert(rank >= N && "Require rank >= N"); if (rank == N) // No collapse is needed, return self. return val; // Compute types. - ArrayRef inputShape = val.getType().cast().getShape(); + ArrayRef inputShape = + mlir::cast(val.getType()).getShape(); Type elementType = getElementType(val.getType()); Type inputShapeType = RankedTensorType::get({rank}, b().getI64Type()); Type keepShapeType = RankedTensorType::get({keep}, b().getI64Type()); @@ -627,7 +629,7 @@ std::vector OnnxBuilder::foldOrEmitONNXSplitOp( SmallVector splitSizesI64; for (auto t : resultTypes) { convertedTypes.emplace_back(create.onnx.toTensor(t)); - splitSizesI64.emplace_back(t.cast().getShape()[axis]); + splitSizesI64.emplace_back(mlir::cast(t).getShape()[axis]); } Value splitSizes = create.onnx.constantInt64(splitSizesI64); ONNXSplitOp split = rewriter.create(loc, convertedTypes, @@ -693,7 +695,7 @@ Value OnnxBuilder::foldOrEmitONNXTransposeOp( getDenseElementAttrFromConstValue(input)) { SmallVector perm; for (auto permVal : permAttr.getValue()) - perm.emplace_back(permVal.cast().getInt()); + perm.emplace_back(mlir::cast(permVal).getInt()); OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); ElementsAttr transposedElements = diff --git a/src/Dialect/ONNX/ElementsAttr/BType.cpp b/src/Dialect/ONNX/ElementsAttr/BType.cpp index a33b114888..8073d2a4e2 100644 --- a/src/Dialect/ONNX/ElementsAttr/BType.cpp +++ b/src/Dialect/ONNX/ElementsAttr/BType.cpp @@ -18,15 +18,15 @@ namespace onnx_mlir { BType btypeOfMlirType(Type type) { // clang-format off - if (type.isa()) return BType::DOUBLE; - if (type.isa()) return BType::FLOAT; - if (type.isa()) return BType::FLOAT16; - if (type.isa()) return BType::BFLOAT16; - if (type.isa()) return BType::FLOAT8E4M3FN; - if (type.isa()) return BType::FLOAT8E4M3FNUZ; - if (type.isa()) return BType::FLOAT8E5M2; - if (type.isa()) return BType::FLOAT8E5M2FNUZ; - auto itype = type.cast(); + if (mlir::isa(type)) return BType::DOUBLE; + if (mlir::isa(type)) return BType::FLOAT; + if (mlir::isa(type)) return BType::FLOAT16; + if (mlir::isa(type)) return BType::BFLOAT16; + if (mlir::isa(type)) return BType::FLOAT8E4M3FN; + if (mlir::isa(type)) return BType::FLOAT8E4M3FNUZ; + if (mlir::isa(type)) return BType::FLOAT8E5M2; + if (mlir::isa(type)) return BType::FLOAT8E5M2FNUZ; + auto itype = mlir::cast(type); switch (itype.getWidth()) { case 1: return BType::BOOL; case 8: return itype.isUnsigned() ? BType::UINT8 : BType::INT8; diff --git a/src/Dialect/ONNX/ElementsAttr/BType.hpp b/src/Dialect/ONNX/ElementsAttr/BType.hpp index 721c2018c3..bb337d7455 100644 --- a/src/Dialect/ONNX/ElementsAttr/BType.hpp +++ b/src/Dialect/ONNX/ElementsAttr/BType.hpp @@ -191,10 +191,10 @@ mlir::Type toMlirType(mlir::MLIRContext *ctx) { // llvm_unreachable("not a supported datatype") // when called with BType::STRING or BType::COMPLEX64/128. -// == mlirTypeOfBType(btype, ctx).isa() +// == mlir::isa(mlirTypeOfBType(btype, ctx)) bool isFloatBType(BType); -// == mlirTypeOfBType(btype, ctx).isa() +// == mlir::isa(mlirTypeOfBType(btype, ctx)) bool isIntBType(BType); // == mlirTypeOfBType(btype, ctx).isIntOrFloat() diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index d0ddb790dc..dd1ab9e1f1 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -112,7 +112,7 @@ class DisposableElementsAttr // Allow implicit conversion to ElementsAttr. operator ElementsAttr() const { - return *this ? cast() : nullptr; + return *this ? mlir::cast(*this) : nullptr; } private: diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 9251b49e38..63d2d3f0e9 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -74,9 +74,9 @@ ElementsAttr ElementsAttrBuilder::fromMemoryBuffer( DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( ElementsAttr elements) { - if (auto disposable = elements.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elements)) return disposable; - if (auto dense = elements.dyn_cast()) { + if (auto dense = mlir::dyn_cast(elements)) { if (!disposablePool.isActive()) return nullptr; ElementsProperties props = getElementsProperties(dense); @@ -85,7 +85,7 @@ DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( // Check for race condition where disposablePool became inactive since we // checked, in which case it returns a DenseElementsAttr which we don't // want. - if (auto disposable = created.dyn_cast()) + if (auto disposable = mlir::dyn_cast(created)) return disposable; else return nullptr; @@ -97,9 +97,9 @@ DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( /*static*/ DenseElementsAttr ElementsAttrBuilder::toDenseElementsAttr( ElementsAttr elements) { - if (auto disposable = elements.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elements)) return disposable.toDenseElementsAttr(); - if (auto dense = elements.dyn_cast()) + if (auto dense = mlir::dyn_cast(elements)) return dense; // TODO: consider supporting more ElementsAttr types llvm_unreachable("unexpected ElementsAttr instance"); @@ -148,7 +148,7 @@ bool ElementsAttrBuilder::allEqual( constexpr BType TAG = toBType; return n.narrow() == x; }; - if (auto disposable = lhs.dyn_cast()) { + if (auto disposable = mlir::dyn_cast(lhs)) { if (disposable.isTransformedOrCast()) { ArrayBuffer nums = disposable.getBufferAsWideNums(); return llvm::all_of(nums.get(), [n](WideNum m) { @@ -159,7 +159,7 @@ bool ElementsAttrBuilder::allEqual( auto values = castArrayRef(disposable.getBufferBytes()); return llvm::all_of(values, nEquals); } - } else if (auto dense = lhs.dyn_cast()) { + } else if (auto dense = mlir::dyn_cast(lhs)) { if (dense.isSplat()) { cpptype x = dense.getSplatValue(); return nEquals(x); @@ -526,7 +526,7 @@ ElementsAttr ElementsAttrBuilder::reshape( return create(reshapedType, props.bufferBType, *reshapedStrides, props.buffer, props.transformer); - auto disp = elms.dyn_cast(); + auto disp = mlir::dyn_cast(elms); assert(disp && "reshapeStrides() always succeeds for non-Disposable " "ElementsAttr as strides are always default or splat"); @@ -1076,13 +1076,13 @@ ElementsAttr ElementsAttrBuilder::nonZero(ElementsAttr elms) { auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) -> ElementsProperties { static Transformer nullTransformer = nullptr; - if (auto disposable = elements.dyn_cast()) { + if (auto disposable = mlir::dyn_cast(elements)) { ArrayRef strides = disposable.getStrides(); return {/*.bufferBType=*/disposable.getBufferBType(), /*.strides=*/{strides.begin(), strides.end()}, /*.buffer=*/disposable.getBuffer(), /*.transformer=*/disposable.getTransformer()}; - } else if (auto dense = elements.dyn_cast()) { + } else if (auto dense = mlir::dyn_cast(elements)) { ShapedType type = dense.getType(); SmallVector strides; if (dense.isSplat()) { @@ -1103,7 +1103,7 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) ArrayBuffer ElementsAttrBuilder::getWideNumsAndExpandedStrides( ElementsAttr elms, llvm::ArrayRef expandedShape, llvm::SmallVectorImpl &expandedStrides) { - if (auto disposable = elms.dyn_cast()) { + if (auto disposable = mlir::dyn_cast(elms)) { expandedStrides = expandStrides(disposable.getStrides(), expandedShape); return disposable.getBufferAsWideNums(); } else if (elms.isSplat()) { diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.cpp index fdb29d4c62..ded0661386 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.cpp @@ -23,7 +23,7 @@ using namespace mlir; namespace onnx_mlir { WideNum getElementsSplatWideNum(ElementsAttr elms) { - if (auto disposable = elms.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elms)) return disposable.getSplatValue(); Type elementType = elms.getElementType(); if (isa(elementType)) @@ -57,11 +57,11 @@ void readDenseElementsWideNums( // everything aligns, otherwise makes and returns a copy. // Precondition: elms.getElementType.isIntOrFloat(). ArrayBuffer getElementsWideNums(ElementsAttr elms) { - if (auto disposable = elms.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elms)) return disposable.getWideNums(); // Return raw data if non-splat DenseElementsAttr and element type is wide. - if (auto dense = elms.dyn_cast()) { + if (auto dense = mlir::dyn_cast(elms)) { auto isWideType = [](Type t) { return t.isInteger(64) || t.isF64(); }; if (isWideType(dense.getElementType()) && !dense.isSplat()) return castArrayRef(dense.getRawData()); @@ -76,7 +76,7 @@ ArrayBuffer getElementsWideNums(ElementsAttr elms) { // Copies out the elements in a flat WideNum array in row-major order. // Precondition: elms.getElementType.isIntOrFloat(). void readElementsWideNums(ElementsAttr elms, MutableArrayRef dst) { - if (auto disposable = elms.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elms)) return disposable.readWideNums(dst); assert(dst.size() == static_cast(elms.size())); readDenseElementsWideNums(elms, dst); diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp.inc b/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp.inc index 241265bd4b..3d86b02c22 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp.inc +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp.inc @@ -6,13 +6,13 @@ template ArrayBuffer getElementsArray(mlir::ElementsAttr elms) { - if (auto disposable = elms.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elms)) return disposable.getArray(); if (elms.isSplat()) return typename ArrayBuffer::Vector(elms.size(), elms.getSplatValue()); if (!elms.getElementType().isInteger(1)) { - if (auto dense = elms.dyn_cast()) { + if (auto dense = mlir::dyn_cast(elms)) { llvm::ArrayRef data = castArrayRef(dense.getRawData()); assert(data.size() == static_cast(elms.size())); return data; @@ -25,14 +25,14 @@ ArrayBuffer getElementsArray(mlir::ElementsAttr elms) { template void readElementsArray(mlir::ElementsAttr elms, llvm::MutableArrayRef dst) { - if (auto disposable = elms.dyn_cast()) + if (auto disposable = mlir::dyn_cast(elms)) return disposable.readArray(dst); if (elms.isSplat()) { assert(dst.size() == static_cast(elms.size())); return std::fill(dst.begin(), dst.end(), elms.getSplatValue()); } if (!elms.getElementType().isInteger(1)) { - if (auto dense = elms.dyn_cast()) { + if (auto dense = mlir::dyn_cast(elms)) { llvm::ArrayRef data = castArrayRef(dense.getRawData()); auto end = std::copy(data.begin(), data.end(), dst.begin()); assert(end == dst.end()); diff --git a/src/Dialect/ONNX/ElementsAttr/WideNum.hpp b/src/Dialect/ONNX/ElementsAttr/WideNum.hpp index e9bd2cb779..bc3adde34a 100644 --- a/src/Dialect/ONNX/ElementsAttr/WideNum.hpp +++ b/src/Dialect/ONNX/ElementsAttr/WideNum.hpp @@ -202,9 +202,9 @@ auto wideZeroDispatch(mlir::Type type, Action &&act); template auto wideZeroDispatchNonBool(mlir::Type type, Action &&act) { - if (type.isa()) + if (mlir::isa(type)) return act(static_cast(0)); - auto itype = type.cast(); + auto itype = mlir::cast(type); if (itype.isUnsigned()) return act(static_cast(0)); else diff --git a/src/Dialect/ONNX/ONNX.td b/src/Dialect/ONNX/ONNX.td index d9f7259ee1..53c52a3ffc 100644 --- a/src/Dialect/ONNX/ONNX.td +++ b/src/Dialect/ONNX/ONNX.td @@ -118,17 +118,17 @@ def ONNXTensorEncodingAttr : ONNX_LayoutAttr<"ONNXTensorEncoding"> { class ONNXCustomDataLayoutAndFactorsOfPred< string layout, int xFactor, int yFactor> : And<[ - CPred<"($_self.cast<::mlir::RankedTensorType>()) &&" - "($_self.cast<::mlir::RankedTensorType>()." - "getEncoding().dyn_cast_or_null()) &&" - "($_self.cast<::mlir::RankedTensorType>()." - "getEncoding().cast().getDataLayout()" + CPred<"(mlir::cast<::mlir::RankedTensorType>($_self)) &&" + "(mlir::dyn_cast_or_null" + "(mlir::cast<::mlir::RankedTensorType>($_self).getEncoding())) &&" + "(mlir::cast(mlir::cast<::mlir::RankedTensorType>($_self)" + ".getEncoding()).getDataLayout()" " == ONNXTensorEncodingAttr::DataLayout::" # layout # ") &&" - "($_self.cast<::mlir::RankedTensorType>()." - "getEncoding().cast().getXFactor()" + "(mlir::cast(mlir::cast<::mlir::RankedTensorType>($_self)" + ".getEncoding()).getXFactor()" " == " # xFactor # ") &&" - "($_self.cast<::mlir::RankedTensorType>()." - "getEncoding().cast().getYFactor()" + "(mlir::cast(mlir::cast<::mlir::RankedTensorType>($_self)" + ".getEncoding()).getYFactor()" " == " # yFactor # ")"> ]>; @@ -203,19 +203,19 @@ def ONNX_OptType : ONNX_Type<"Opt", "Opt"> { // Can be used in other table gen files (.td) for onnx dialect //===----------------------------------------------------------------------===// -def StringType : Type()">, "string type">; +def StringType : Type($_self)">, "string type">; -def IsSeqTypePred : CPred<"$_self.isa()">; +def IsSeqTypePred : CPred<"mlir::isa($_self)">; -class SeqOf allowedTypes> : +class SeqOf allowedTypes> : ContainerType, IsSeqTypePred, - "$_self.cast().getElementType()", "SeqType">; + "mlir::cast($_self).getElementType()", "SeqType">; -def IsOptTypePred : CPred<"$_self.isa()">; +def IsOptTypePred : CPred<"mlir::isa($_self)">; class OptOf : ContainerType().getElementType()", "OptType">; + "mlir::cast($_self).getElementType()", "OptType">; def ONNXConstantOpFromDenseAttr: NativeCodeCall< "onnx_mlir::OnnxBuilder($_builder, $_loc).constant($0)">; @@ -246,7 +246,7 @@ class ONNX_Op traits = []> : // 1. Attributes are not processed // 2. output type inference not implemented except Add // 3. Type Attribute: 'optional' and 'Variadic hetergeneous' are ignored -// 4. type of string, complex64 and complex128 for input/output are ignored +// 4. type of string, complex64 and complex128 for input/output are ignored // 5. unsigned int are treated as signed one include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/src/Dialect/ONNX/ONNXAttributes.cpp b/src/Dialect/ONNX/ONNXAttributes.cpp index 48fdc4359e..b4cac12687 100644 --- a/src/Dialect/ONNX/ONNXAttributes.cpp +++ b/src/Dialect/ONNX/ONNXAttributes.cpp @@ -65,7 +65,7 @@ Attribute ONNXTensorEncodingAttr::parse(AsmParser &parser, Type type) { // Process the data from the parsed dictionary value into struct-like data. for (const NamedAttribute &attr : dict) { if (attr.getName() == "dataLayout") { - StringAttr layoutAttr = attr.getValue().dyn_cast(); + StringAttr layoutAttr = mlir::dyn_cast(attr.getValue()); if (!layoutAttr) { parser.emitError( parser.getNameLoc(), "expected a string value for data layout"); @@ -124,7 +124,7 @@ Attribute ONNXDialect::parseAttribute( generatedAttributeParser(parser, &attrTag, type, attr).has_value()) return attr; if (attrTag == DisposableElementsAttr::getMnemonic()) { - auto shapedTy = type.cast(); + auto shapedTy = mlir::cast(type); if (auto membuf = DisposableElementsAttr::parse(parser, shapedTy)) return OnnxElementsAttrBuilder(type.getContext()) .fromMemoryBuffer(shapedTy, std::move(membuf)); @@ -142,6 +142,6 @@ void ONNXDialect::printAttribute( // generatedAttributePrinter is generated in ONNXAttributes.cpp.inc if (succeeded(generatedAttributePrinter(attr, printer))) return; - if (auto elements = attr.dyn_cast()) + if (auto elements = mlir::dyn_cast(attr)) elements.printWithoutType(printer); } diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index c1b06859d8..62e5b51e81 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -71,7 +71,7 @@ static std::optional insertDimWhenUseful(const Value tensor, uint64_t axis = dimIndex; bool okToInsert = false; - if (tensor.isa()) { + if (mlir::isa(tensor)) { okToInsert = true; } else { Operation *op = tensor.getDefiningOp(); @@ -550,7 +550,7 @@ void DimAnalysis::buildFunctionArgsRes(func::FuncOp funcOp) { auto buildFor = [¶mSetMap, this](ValueRange args, ArrayAttr argAttrs) { for (size_t argPos = 0; argPos < args.size(); ++argPos) { Value arg = args[argPos]; - auto tensorType = arg.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(arg.getType()); if (!tensorType) continue; // Get dim_params if exists. @@ -603,7 +603,7 @@ void DimAnalysis::buildFunctionArgsRes(func::FuncOp funcOp) { } void DimAnalysis::build(Value val) { - if (auto tensorType = val.getType().dyn_cast()) { + if (auto tensorType = mlir::dyn_cast(val.getType())) { for (unsigned i = 0; i < tensorType.getRank(); ++i) { // Only care about dynamic dimensions. if (tensorType.isDynamicDim(i)) { @@ -624,8 +624,8 @@ void DimAnalysis::build(Value val) { bool DimAnalysis::sameDim( Value tensor1, int64_t dimAxis1, Value tensor2, int64_t dimAxis2) const { // Handle negative axis and test if in bound. - ShapedType tensor1Type = tensor1.getType().cast(); - ShapedType tensor2Type = tensor2.getType().cast(); + ShapedType tensor1Type = mlir::cast(tensor1.getType()); + ShapedType tensor2Type = mlir::cast(tensor2.getType()); if (!handleAndTestInBound(dimAxis1, tensor1Type) || !handleAndTestInBound(dimAxis2, tensor2Type)) return false; @@ -647,8 +647,8 @@ bool DimAnalysis::sameDim( bool DimAnalysis::sameDynDim( Value tensor1, int64_t dimAxis1, Value tensor2, int64_t dimAxis2) const { // Handle negative axis and test if in bound. - ShapedType tensor1Type = tensor1.getType().cast(); - ShapedType tensor2Type = tensor2.getType().cast(); + ShapedType tensor1Type = mlir::cast(tensor1.getType()); + ShapedType tensor2Type = mlir::cast(tensor2.getType()); if (!handleAndTestInBound(dimAxis1, tensor1Type) || !handleAndTestInBound(dimAxis2, tensor2Type)) return false; @@ -669,7 +669,7 @@ bool DimAnalysis::sameDynDim( bool DimAnalysis::sameShape(Value tensor1, Value tensor2) const { if (!sameRank(tensor1, tensor2)) return false; - unsigned rank = tensor1.getType().cast().getRank(); + unsigned rank = mlir::cast(tensor1.getType()).getRank(); // Check each dimension. for (unsigned i = 0; i < rank; ++i) { if (!sameDim(tensor1, i, tensor2, i)) @@ -681,8 +681,10 @@ bool DimAnalysis::sameShape(Value tensor1, Value tensor2) const { bool DimAnalysis::sameDynShape(Value tensor1, Value tensor2) const { if (!sameRank(tensor1, tensor2)) return false; - ArrayRef shape1 = tensor1.getType().cast().getShape(); - ArrayRef shape2 = tensor2.getType().cast().getShape(); + ArrayRef shape1 = + mlir::cast(tensor1.getType()).getShape(); + ArrayRef shape2 = + mlir::cast(tensor2.getType()).getShape(); // Check each dimension. for (unsigned i = 0; i < shape1.size(); ++i) { int64_t dim1 = shape1[i]; @@ -700,7 +702,8 @@ bool DimAnalysis::sameDynShape(Value tensor1, Value tensor2) const { bool DimAnalysis::broadcastLastDim(Value tensor1, Value tensor2) const { if (!sameRank(tensor1, tensor2)) return false; - ArrayRef shape1 = tensor1.getType().cast().getShape(); + ArrayRef shape1 = + mlir::cast(tensor1.getType()).getShape(); unsigned rank = shape1.size(); // The last dimension of tensor1 must be 1, so that tensor1 is broadcasting // to tensor2. @@ -738,10 +741,8 @@ void DimAnalysis::getONNXDimParams( DictionaryAttr dictAttr = llvm::dyn_cast(argResAttr[index]); if (dictAttr && dictAttr.contains("onnx.dim_params")) { // onnx.dim_params = dimIndex:dimParam,dimIndex:dimParam,... - StringRef dimParams = dictAttr.getNamed("onnx.dim_params") - .value() - .getValue() - .cast() + StringRef dimParams = mlir::cast( + dictAttr.getNamed("onnx.dim_params").value().getValue()) .getValue(); SmallVector splittedDimParams; dimParams.split(splittedDimParams, ','); @@ -850,7 +851,7 @@ void DimAnalysis::visitDim( // operation's shape helper for this purpose as much as possible. // Tensor is a block argument. There is no defining operator to explore. - if (tensor.isa()) + if (mlir::isa(tensor)) return; // Get the defining operator. @@ -877,7 +878,7 @@ void DimAnalysis::visitDim( // All dimensions in the analysis must be dynamic. If not, something really // wrong happened. - ShapedType ty = tensor.getType().cast(); + ShapedType ty = mlir::cast(tensor.getType()); assert(ty.isDynamicDim(dimIndex) && "There is a static dim in the analysis. " "Something really wrong happened."); @@ -1062,7 +1063,7 @@ void ONNXDimAnalysisPass::runOnOperation() { Value val = ti.first; uint64_t dimAxis = ti.second; Location loc = val.getLoc(); - if (auto arg = val.dyn_cast()) { + if (auto arg = mlir::dyn_cast(val)) { Block *owner = arg.getOwner(); b = OpBuilder::atBlockBegin(owner); } else { diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 6ddc427b4f..892cf01cf1 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -46,21 +46,21 @@ using namespace onnx_mlir; //===----------------------------------------------------------------------===// Type getBroadcastedRankedType( Type type1, Type type2, Type elementType = nullptr) { - if (type1.isa() && type2.isa()) + if (mlir::isa(type1) && mlir::isa(type2)) return OpTrait::util::getBroadcastedType(type1, type2, elementType); - if (type1.isa() && type2.isa()) { + if (mlir::isa(type1) && mlir::isa(type2)) { // Construct RankedTensorType(s). if (!elementType) - elementType = type1.cast().getElementType(); - RankedTensorType ty1 = - RankedTensorType::get(type1.cast().getShape(), elementType); - RankedTensorType ty2 = - RankedTensorType::get(type2.cast().getShape(), elementType); + elementType = mlir::cast(type1).getElementType(); + RankedTensorType ty1 = RankedTensorType::get( + mlir::cast(type1).getShape(), elementType); + RankedTensorType ty2 = RankedTensorType::get( + mlir::cast(type2).getShape(), elementType); // Compute a broadcasted type. Type outputType = OpTrait::util::getBroadcastedType(ty1, ty2); // Construct a MemRefType. return MemRefType::get( - outputType.cast().getShape(), elementType); + mlir::cast(outputType).getShape(), elementType); } else return {}; } @@ -86,7 +86,7 @@ namespace { // DisposableElementsAttr is an internal representation, so we hide it // in this way. void printAttribute(OpAsmPrinter &printer, Attribute attr) { - if (auto disposable = attr.dyn_cast()) + if (auto disposable = mlir::dyn_cast(attr)) disposable.printAsDenseElementsAttr(printer); else printer.printAttribute(attr); @@ -97,7 +97,7 @@ void printNamedAttribute(OpAsmPrinter &printer, NamedAttribute namedAttr) { printer.printKeywordOrString(namedAttr.getName().strref()); // Pretty printing elides the attribute value for unit attributes. - if (namedAttr.getValue().isa()) + if (mlir::isa(namedAttr.getValue())) return; printer << " = "; @@ -152,8 +152,8 @@ void ONNXConstantOp::print(OpAsmPrinter &printer) { Type resultType = getResult().getType(); if (auto attr = getValue()) { // ONNXConstantOp value must be ElementsAttr, but not SparseElementsAttr. - auto elements = attr->cast(); - assert(!elements.isa() && + auto elements = mlir::cast(*attr); + assert(!mlir::isa(elements) && "ONNXConstantOp value cannot be sparse"); if (elements.getType() == resultType) { printer << ' '; @@ -163,7 +163,7 @@ void ONNXConstantOp::print(OpAsmPrinter &printer) { } if (auto attr = getSparseValue()) { // ONNXConstantOp sparse_value must be SparseElementsAttr. - auto sparseElements = attr->cast(); + auto sparseElements = mlir::cast(*attr); if (sparseElements.getType() == resultType) { printer << ' '; printer.printAttribute(sparseElements); @@ -191,9 +191,9 @@ ParseResult ONNXConstantOp::parse(OpAsmParser &parser, OperationState &result) { if (*opt) return failure(); const char *name = - attr.isa() ? "sparse_value" : "value"; + mlir::isa(attr) ? "sparse_value" : "value"; result.addAttribute(name, attr); - result.addTypes({attr.cast().getType()}); + result.addTypes({mlir::cast(attr).getType()}); return success(); } // No sparse_value or value attr, so attribute dictionary really is empty. diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index aee877be5b..61f3687366 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -16,11 +16,11 @@ def ONNXAbsOp:ONNX_Op<"Abs", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$Y); let builders = [ OpBuilder<(ins "Value":$X), [{ - auto resultType = UnrankedTensorType::get(X.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(X.getType()).getElementType()); build($_builder, $_state, resultType, X); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -122,18 +122,18 @@ def ONNXAddOp:ONNX_Op<"Add", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -177,18 +177,18 @@ def ONNXAndOp:ONNX_Op<"And", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -1348,11 +1348,11 @@ def ONNXConstantOp:ONNX_Op<"Constant", let builders = [ OpBuilder<(ins "Attribute":$sparse_value, "Attribute":$value), [{ if (value) { - auto tensorType = value.cast().getType(); + auto tensorType = mlir::cast(value).getType(); build($_builder, $_state, tensorType, sparse_value, value, FloatAttr(), ArrayAttr(), IntegerAttr(), ArrayAttr(), StringAttr(), ArrayAttr()); } else { - auto tensorType = sparse_value.cast().getType(); + auto tensorType = mlir::cast(sparse_value).getType(); build($_builder, $_state, tensorType, sparse_value, value, FloatAttr(), ArrayAttr(), IntegerAttr(), ArrayAttr(), StringAttr(), ArrayAttr()); } @@ -1412,11 +1412,11 @@ def ONNXConvOp:ONNX_Op<"Conv", let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let builders = [ OpBuilder<(ins "Value":$X, "Value":$W, "Value":$B, "StringAttr":$auto_pad, "ArrayAttr":$dilations, "IntegerAttr":$group, "ArrayAttr":$kernel_shape, "ArrayAttr":$pads, "ArrayAttr":$strides), [{ - auto resultType = UnrankedTensorType::get(X.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(X.getType()).getElementType()); build($_builder, $_state, resultType, X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -1911,18 +1911,18 @@ def ONNXDivOp:ONNX_Op<"Div", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -2157,7 +2157,7 @@ def ONNXEqualOp:ONNX_Op<"Equal", auto rhsTy = B.getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, resultType, A, B); @@ -2167,7 +2167,7 @@ def ONNXEqualOp:ONNX_Op<"Equal", auto rhsTy = operands[1].getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, {resultType}, operands, attributes); @@ -2234,11 +2234,11 @@ def ONNXExpOp:ONNX_Op<"Exp", let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$output); let builders = [ OpBuilder<(ins "Value":$input), [{ - auto resultType = UnrankedTensorType::get(input.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(input.getType()).getElementType()); build($_builder, $_state, resultType, input); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -2966,7 +2966,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater", auto rhsTy = B.getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, resultType, A, B); @@ -2976,7 +2976,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater", auto rhsTy = operands[1].getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, {resultType}, operands, attributes); @@ -3022,7 +3022,7 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", auto rhsTy = B.getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, resultType, A, B); @@ -3032,7 +3032,7 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", auto rhsTy = operands[1].getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, {resultType}, operands, attributes); @@ -3326,11 +3326,11 @@ def ONNXIdentityOp:ONNX_Op<"Identity", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex]>]>, SeqOf<[TensorOf<[Complex]>]>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>>, OptOf]>]>>, OptOf]>]>>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf>, OptOf]>>, OptOf]>>]>:$output); let builders = [ OpBuilder<(ins "Value":$input), [{ - auto resultType = UnrankedTensorType::get(input.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(input.getType()).getElementType()); build($_builder, $_state, resultType, input); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -3755,7 +3755,7 @@ def ONNXLessOp:ONNX_Op<"Less", auto rhsTy = B.getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, resultType, A, B); @@ -3765,7 +3765,7 @@ def ONNXLessOp:ONNX_Op<"Less", auto rhsTy = operands[1].getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, {resultType}, operands, attributes); @@ -3811,7 +3811,7 @@ def ONNXLessOrEqualOp:ONNX_Op<"LessOrEqual", auto rhsTy = B.getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, resultType, A, B); @@ -3821,7 +3821,7 @@ def ONNXLessOrEqualOp:ONNX_Op<"LessOrEqual", auto rhsTy = operands[1].getType(); auto elTy = $_builder.getI1Type(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy, elTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) resultType = UnrankedTensorType::get(elTy); build($_builder, $_state, {resultType}, operands, attributes); @@ -4661,18 +4661,18 @@ def ONNXMulOp:ONNX_Op<"Mul", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -4743,11 +4743,11 @@ def ONNXNegOp:ONNX_Op<"Neg", let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$Y); let builders = [ OpBuilder<(ins "Value":$X), [{ - auto resultType = UnrankedTensorType::get(X.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(X.getType()).getElementType()); build($_builder, $_state, resultType, X); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -5173,18 +5173,18 @@ def ONNXOrOp:ONNX_Op<"Or", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -5357,11 +5357,11 @@ def ONNXPadOp:ONNX_Op<"Pad", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$output); let builders = [ OpBuilder<(ins "Value":$data, "Value":$pads, "Value":$constant_value, "Value":$axes, "StringAttr":$mode), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, pads, constant_value, axes, mode); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -5771,18 +5771,18 @@ def ONNXPowOp:ONNX_Op<"Pow", auto lhsTy = X.getType(); auto rhsTy = Y.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, X, Y); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -6448,11 +6448,11 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes, "IntegerAttr":$keepdims, "IntegerAttr":$noop_with_empty_axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims, noop_with_empty_axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -6614,11 +6614,11 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>, TensorOf<[UI8]>, TensorOf<[I8]>, TensorOf<[I1]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes, "IntegerAttr":$keepdims, "IntegerAttr":$noop_with_empty_axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims, noop_with_empty_axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -6663,11 +6663,11 @@ def ONNXReduceMaxV18Op:ONNX_Op<"ReduceMaxV18", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>, TensorOf<[UI8]>, TensorOf<[I8]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes, "IntegerAttr":$keepdims, "IntegerAttr":$noop_with_empty_axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims, noop_with_empty_axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -6711,11 +6711,11 @@ def ONNXReduceMaxV13Op:ONNX_Op<"ReduceMaxV13", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>, TensorOf<[UI8]>, TensorOf<[I8]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "ArrayAttr":$axes, "IntegerAttr":$keepdims), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -7032,11 +7032,11 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes, "IntegerAttr":$keepdims, "IntegerAttr":$noop_with_empty_axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims, noop_with_empty_axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -7078,11 +7078,11 @@ def ONNXReduceSumV11Op:ONNX_Op<"ReduceSumV11", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "ArrayAttr":$axes, "IntegerAttr":$keepdims), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -7127,11 +7127,11 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", let results = (outs AnyTypeOf<[TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$reduced); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes, "IntegerAttr":$keepdims, "IntegerAttr":$noop_with_empty_axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes, keepdims, noop_with_empty_axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -8749,11 +8749,11 @@ def ONNXSoftmaxOp:ONNX_Op<"Softmax", let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$output); let builders = [ OpBuilder<(ins "Value":$input, "IntegerAttr":$axis), [{ - auto resultType = UnrankedTensorType::get(input.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(input.getType()).getElementType()); build($_builder, $_state, resultType, input, axis); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9006,11 +9006,11 @@ def ONNXSplitOp:ONNX_Op<"Split", let results = (outs Variadic, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>>:$outputs); let builders = [ OpBuilder<(ins "Value":$input, "Value":$split, "IntegerAttr":$axis, "IntegerAttr":$num_outputs), [{ - auto resultType = UnrankedTensorType::get(input.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(input.getType()).getElementType()); build($_builder, $_state, resultType, input, split, axis, num_outputs); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9050,11 +9050,11 @@ def ONNXSplitV13Op:ONNX_Op<"SplitV13", let results = (outs Variadic, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>>:$outputs); let builders = [ OpBuilder<(ins "Value":$input, "Value":$split, "IntegerAttr":$axis), [{ - auto resultType = UnrankedTensorType::get(input.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(input.getType()).getElementType()); build($_builder, $_state, resultType, input, split, axis); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9168,11 +9168,11 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt", let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$Y); let builders = [ OpBuilder<(ins "Value":$X), [{ - auto resultType = UnrankedTensorType::get(X.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(X.getType()).getElementType()); build($_builder, $_state, resultType, X); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9212,11 +9212,11 @@ def ONNXSqueezeOp:ONNX_Op<"Squeeze", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$squeezed); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9257,11 +9257,11 @@ def ONNXSqueezeV11Op:ONNX_Op<"SqueezeV11", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$squeezed); let builders = [ OpBuilder<(ins "Value":$data, "ArrayAttr":$axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9347,18 +9347,18 @@ def ONNXSubOp:ONNX_Op<"Sub", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9879,11 +9879,11 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$expanded); let builders = [ OpBuilder<(ins "Value":$data, "Value":$axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -9931,11 +9931,11 @@ def ONNXUnsqueezeV11Op:ONNX_Op<"UnsqueezeV11", let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$expanded); let builders = [ OpBuilder<(ins "Value":$data, "ArrayAttr":$axes), [{ - auto resultType = UnrankedTensorType::get(data.getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(data.getType()).getElementType()); build($_builder, $_state, resultType, data, axes); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ - auto resultType = UnrankedTensorType::get(operands[0].getType().cast().getElementType()); + auto resultType = UnrankedTensorType::get(mlir::cast(operands[0].getType()).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; @@ -10083,18 +10083,18 @@ def ONNXXorOp:ONNX_Op<"Xor", auto lhsTy = A.getType(); auto rhsTy = B.getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, resultType, A, B); }]>, OpBuilder<(ins "ValueRange":$operands, "ArrayRef":$attributes), [{ auto lhsTy = operands[0].getType(); auto rhsTy = operands[1].getType(); auto resultType = getBroadcastedRankedType(lhsTy, rhsTy); - auto shapedType = resultType.dyn_cast_or_null(); + auto shapedType = mlir::dyn_cast_or_null(resultType); if (!shapedType || !shapedType.hasStaticShape()) - resultType = UnrankedTensorType::get(lhsTy.cast().getElementType()); + resultType = UnrankedTensorType::get(mlir::cast(lhsTy).getElementType()); build($_builder, $_state, {resultType}, operands, attributes); }]> ]; diff --git a/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp b/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp index fa3f4b025f..fdf6991771 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp @@ -55,7 +55,7 @@ LogicalResult ONNXConcatShapeTransposeOpShapeHelper::computeShape() { unsigned numInputs = concatOp.getNumOperands(); Value firstInput = operandAdaptor.getInputs().front(); ArrayRef commonShape = - firstInput.getType().cast().getShape(); + mlir::cast(firstInput.getType()).getShape(); int64_t commonRank = commonShape.size(); int64_t axisIndex = concatOp.getAxis(); @@ -154,8 +154,8 @@ LogicalResult ONNXConcatShapeTransposeOp::inferShapes( // If any input is not ranked tensor, do nothing. if (!hasShapeAndRank(getOperation())) return success(); - auto commonType = getOperand(0).getType().cast(); - Type intType = IntegerType::get(getContext(), 64).cast(); + auto commonType = mlir::cast(getOperand(0).getType()); + Type intType = mlir::cast(IntegerType::get(getContext(), 64)); SmallVector elementTypes = {intType, commonType.getElementType()}; ONNXConcatShapeTransposeOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateTypes(elementTypes); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Custom.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Custom.cpp index 90a5a4ae52..00863e8006 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/Custom.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/Custom.cpp @@ -44,7 +44,7 @@ LogicalResult ONNXCustomOp::inferShapes( std::optional inputIndexAttrs = getInputsForInfer(); int64_t inputIdx = 0; if (inputIndexAttrs.has_value()) - inputIdx = (inputIndexAttrs->getValue()[0]).cast().getInt(); + inputIdx = mlir::cast(inputIndexAttrs->getValue()[0]).getInt(); Type elementType = getOutputElementType().value_or( getElementType(getInputs()[inputIdx].getType())); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/LayoutTransform.cpp b/src/Dialect/ONNX/ONNXOps/Additional/LayoutTransform.cpp index cc35a2a095..19744eaff5 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/LayoutTransform.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/LayoutTransform.cpp @@ -28,7 +28,7 @@ LogicalResult ONNXLayoutTransformOp::inferShapes( return success(); Type elementType = - getData().getType().dyn_cast().getElementType(); + mlir::dyn_cast(getData().getType()).getElementType(); ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType( elementType, getTargetLayoutAttr()); @@ -38,8 +38,9 @@ LogicalResult ONNXLayoutTransformOp::inferShapes( // Verifier //===----------------------------------------------------------------------===// LogicalResult ONNXLayoutTransformOp::verify() { - if (auto dataType = getData().getType().dyn_cast()) { - if (auto outputType = getOutput().getType().dyn_cast()) { + if (auto dataType = mlir::dyn_cast(getData().getType())) { + if (auto outputType = + mlir::dyn_cast(getOutput().getType())) { for (int64_t i = 0; i < dataType.getRank(); ++i) { // Check if there is an unknown dimension in the dataShape and // outputShape. If there is an unknown dimension, we will return true. diff --git a/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp b/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp index 303fbf8046..26a9e5cd54 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp @@ -28,7 +28,7 @@ LogicalResult ONNXShapeTransformOpShapeHelper::computeShape() { Value input = operandAdaptor.getInput(); AffineMap indexMap = operandAdaptor.getIndexMap(); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); ArrayRef inputDims = inputType.getShape(); int64_t outputRank = indexMap.getNumResults(); @@ -73,7 +73,7 @@ LogicalResult ONNXShapeTransformOp::inferShapes( if (!hasShapeAndRank(op)) return success(); // Input and output have the same element type and encoding. - auto inputType = getOperand().getType().cast(); + auto inputType = mlir::cast(getOperand().getType()); ONNXShapeTransformOpShapeHelper shapeHelper(op, {}); return shapeHelper.computeShapeAndUpdateTypes( inputType.getElementType(), inputType.getEncoding()); diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp index 5e4832bec9..fa6c23030b 100644 --- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp +++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp @@ -43,7 +43,7 @@ namespace onnx_mlir { // If 'A' is NoneType, return -B. Otherwise return A-B. Value subtractOrNeg(PatternRewriter &rewriter, Location loc, Value A, Value B) { - if (A.getType().isa()) + if (mlir::isa(A.getType())) return rewriter.create(loc, B); return rewriter.create(loc, A, B); } @@ -67,12 +67,14 @@ DenseElementsAttr createDenseElementsAttrOfNToM( // Get return type for a MatMulOp whose A's rank is N (>2) and B's rank is 2. Type getReturnTypeForMatMulOpND2D(Value A, Value B) { - ArrayRef aShape = A.getType().cast().getShape(); - ArrayRef bShape = B.getType().cast().getShape(); + ArrayRef aShape = + mlir::cast(A.getType()).getShape(); + ArrayRef bShape = + mlir::cast(B.getType()).getShape(); SmallVector resShape(aShape.begin(), aShape.end() - 1); resShape.emplace_back(bShape[bShape.size() - 1]); return RankedTensorType::get( - resShape, A.getType().cast().getElementType()); + resShape, mlir::cast(A.getType()).getElementType()); } // Get the index of the axis value in the given permutation array. @@ -80,7 +82,7 @@ IntegerAttr getIndexOfAxisInPerm( PatternRewriter &rewriter, ArrayAttr permAttr, IntegerAttr axis) { IntegerAttr result; for (uint64_t i = 0; i < permAttr.getValue().size(); ++i) { - IntegerAttr attr = permAttr.getValue()[i].cast(); + IntegerAttr attr = mlir::cast(permAttr.getValue()[i]); assert(attr && "Element in ArrayAttr is not IntegerAttr"); if (attr.getValue().getSExtValue() == axis.getValue().getSExtValue()) return rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), i); @@ -93,7 +95,7 @@ SmallVector transposeVariadicInput(PatternRewriter &rewriter, Location loc, ValueRange inputs, ArrayAttr permAttr) { SmallVector transposedInputs; for (Value inp : inputs) { - ShapedType inpType = inp.getType().cast(); + ShapedType inpType = mlir::cast(inp.getType()); assert(inpType && "Type is not ShapedType"); ONNXTransposeOp transposeOp = rewriter.create( loc, UnrankedTensorType::get(inpType.getElementType()), inp, permAttr); @@ -108,7 +110,7 @@ SmallVector castVariadicInput(PatternRewriter &rewriter, Location loc, ValueRange inputs, IntegerAttr saturate, TypeAttr to) { SmallVector castInputs; for (Value inp : inputs) { - ShapedType inpType = inp.getType().cast(); + ShapedType inpType = mlir::cast(inp.getType()); assert(inpType && "Type is not ShapedType"); ONNXCastOp castOp = rewriter.create(loc, UnrankedTensorType::get(inpType.getElementType()), inp, saturate, to); @@ -121,7 +123,7 @@ SmallVector castVariadicInput(PatternRewriter &rewriter, Location loc, // Check if all values are produced by ONNXTransposeOp. bool areProducedByTransposeOp(ValueRange values) { return llvm::all_of(values, [](Value v) { - if (v.isa()) + if (mlir::isa(v)) return false; return isa(v.getDefiningOp()); }); @@ -131,7 +133,7 @@ bool areProducedByTransposeOp(ValueRange values) { DenseElementsAttr createDenseElementsAttrFromShape(PatternRewriter &rewriter, Value value, int64_t start = 0, std::optional end = std::nullopt) { - auto inType = value.getType().cast(); + auto inType = mlir::cast(value.getType()); assert(inType.hasRank() && "inType must be ranked"); auto shape = inType.getShape(); int64_t rank = inType.getRank(); @@ -167,7 +169,7 @@ bool AreTheSameAxesArrayAttr( auto asSet = [rank](ArrayRef array) { llvm::SmallSet axes; for (auto attr : array) { - int64_t axis = attr.cast().getInt(); + int64_t axis = mlir::cast(attr).getInt(); axes.insert(axis < 0 ? axis + rank : axis); } return axes; @@ -211,11 +213,11 @@ bool isNegativeSplatConstant(Value val) { if (!valAttr.isSplat()) return false; - Type elemTy = val.getType().cast().getElementType(); - if (elemTy.isa()) { + Type elemTy = mlir::cast(val.getType()).getElementType(); + if (mlir::isa(elemTy)) { double v = valAttr.getSplatValue(); return (v < 0.0); - } else if (elemTy.isa()) { + } else if (mlir::isa(elemTy)) { int64_t v = valAttr.getSplatValue(); return (v < 0); } @@ -226,15 +228,15 @@ bool isNegativeSplatConstant(Value val) { bool areAllDimSizes(ValueRange vals) { return llvm::all_of(vals, [](Value val) { // Block arguments. - if (val.isa()) + if (mlir::isa(val)) return false; // Defined by DimOp. if (val.getDefiningOp()) return true; // Defined by ConstantOp. if (isDenseONNXConstant(val) && isScalarTensor(val)) { - Type elemTy = val.getType().cast().getElementType(); - if (!elemTy.isa()) + Type elemTy = mlir::cast(val.getType()).getElementType(); + if (!mlir::isa(elemTy)) return false; ONNXConstantOp constOp = val.getDefiningOp(); auto valAttr = @@ -255,7 +257,7 @@ bool areAllDimSizes(ValueRange vals) { // A and B are constants. bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, Operation *&matmulOrGemmOp, Operation *&addOp, bool &isGemm) { - if (v.isa()) + if (mlir::isa(v)) return false; if (!hasOneUseExceptDimOp(v)) return false; @@ -270,7 +272,7 @@ bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, if (!hasOneUseExceptDimOp(origV)) break; } - if (origV.isa() || !hasOneUseExceptDimOp(origV)) + if (mlir::isa(origV) || !hasOneUseExceptDimOp(origV)) return false; // Match Gemm @@ -745,14 +747,14 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // A helper function to check whether a value is defined by ONNXConstantOp in // the same block or not. bool isDefinedByIntegerConstantOp(Value v) const { - if (v.isa()) + if (mlir::isa(v)) return false; Operation *definingOp = v.getDefiningOp(); - if (v.getType().cast().getElementType().isa() && + if (mlir::isa( + mlir::cast(v.getType()).getElementType()) && isa(definingOp) && - cast(definingOp) - .getValueAttr() - .isa()) + mlir::isa( + cast(definingOp).getValueAttr())) return true; return false; } @@ -762,38 +764,40 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // shifted by 1 to the left in YieldOp. If a block argument is unchanged when // being shifted in YieldOp, then it is invariant to iterations. bool isInvariantBlockArg(Value v, Operation *yieldOp) const { - return v.isa() && - (v == yieldOp->getOperands()[v.cast().getArgNumber() - - 1]); + return mlir::isa(v) && + (v == + yieldOp + ->getOperands()[mlir::cast(v).getArgNumber() - + 1]); } // A helper function to check whether a value is defined by ONNXConstantOp in // the same block or an invariant block argument. bool isIntConstantOrInvariantBlockArg(Value v, Operation *yieldOp) const { - return ((v.isa() && isInvariantBlockArg(v, yieldOp)) || - (!v.isa() && isDefinedByIntegerConstantOp(v))); + return ((mlir::isa(v) && isInvariantBlockArg(v, yieldOp)) || + (!mlir::isa(v) && isDefinedByIntegerConstantOp(v))); } // A helper function to check whether an block argument is updated by a Value // inside the loop or not. bool isUpdatedArgByValue(Value v, Value newV, Operation *yieldOp) const { - return v.isa() && + return mlir::isa(v) && (newV == yieldOp - ->getOperands()[v.cast().getArgNumber() - 1]); + ->getOperands()[mlir::cast(v).getArgNumber() - + 1]); } // A helper function to get the value that is fed to an operation's argument. Value getFedValue(Value arg, Operation *op) const { - return op->getOperands()[arg.cast().getArgNumber()]; + return op->getOperands()[mlir::cast(arg).getArgNumber()]; } // A helper function to get an integer constant from a value. int64_t getOneIntegerConstant(Value v) const { Operation *definingOp = v.getDefiningOp(); - DenseElementsAttr valueAttr = cast(definingOp) - .getValueAttr() - .cast(); + DenseElementsAttr valueAttr = mlir::cast( + cast(definingOp).getValueAttr()); return (*valueAttr.getValues().begin()).getSExtValue(); } @@ -840,7 +844,7 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // The break condition is the first argument of YieldOp. // `ONNXYieldOp (cond, ..., ubValue, ..., newCounterValue, ...)` Value breakCond = yieldOp->getOperands()[0]; - if (breakCond.isa()) + if (mlir::isa(breakCond)) return std::make_pair(false, maxTripCountValue); Operation *breakCondOp = breakCond.getDefiningOp(); @@ -852,10 +856,8 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { Value newCounterValue = breakCondOp->getOperands()[0]; Value ubValue = breakCondOp->getOperands()[1]; // Input type of Less must be integer. - if (!newCounterValue.getType() - .cast() - .getElementType() - .isa()) + if (!mlir::isa( + mlir::cast(newCounterValue.getType()).getElementType())) return std::make_pair(false, maxTripCountValue); // Compute a trip count from the break condition, given that the upper bound @@ -863,7 +865,7 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // iteration. So, the trip count will be `(upper_bound - lower_bound)/step`. // Only support ONNXAddOp at this moment. - if (newCounterValue.isa() || + if (mlir::isa(newCounterValue) || !isa(newCounterValue.getDefiningOp())) return std::make_pair(false, maxTripCountValue); // ONNXLoop(max_trip_count, true, ..., ubValue, ..., startValue, ...) @@ -925,8 +927,9 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { SmallVector values(1, derivedTripCount); DenseElementsAttr valueAttr = DenseElementsAttr::get( - RankedTensorType::get({}, - maxTripCountValue.getType().cast().getElementType()), + RankedTensorType::get( + {}, mlir::cast(maxTripCountValue.getType()) + .getElementType()), ArrayRef(values)); return std::make_pair(true, onnx.constant(valueAttr)); } @@ -936,14 +939,14 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // - new_max_trip_count = // min(old_max_trip_count, ceil(upper_bound - lower_bound)/step) TypeAttr tripCountType = TypeAttr::get( - maxTripCountValue.getType().cast().getElementType()); + mlir::cast(maxTripCountValue.getType()).getElementType()); // Cast the upper and lower bounds to the correct type. - if (maxTripCountValue.getType().cast().getElementType() != - ubValue.getType().cast().getElementType()) + if (mlir::cast(maxTripCountValue.getType()).getElementType() != + mlir::cast(ubValue.getType()).getElementType()) ubValue = onnx.cast(ubValue, tripCountType); - if (maxTripCountValue.getType().cast().getElementType() != - lbValue.getType().cast().getElementType()) + if (mlir::cast(maxTripCountValue.getType()).getElementType() != + mlir::cast(lbValue.getType()).getElementType()) lbValue = onnx.cast(lbValue, tripCountType); // Emit code to compute the max trip count. @@ -987,14 +990,14 @@ class InputOutputTransposer { void transposeInput(MutableOperandRange operand, ArrayAttr perm) { assert(operand.size() == 1 && "should be called with singleton range"); Value input = operand[0].get(); - if (!input.getType().isa()) { + if (!mlir::isa(input.getType())) { Value transposed = transpose(input, perm); operand.assign(transposed); } } void transposeOutput(Value output, ArrayAttr perm) { - if (!output.getType().isa()) { + if (!mlir::isa(output.getType())) { Value transposed = transpose(output, perm); output.replaceAllUsesExcept(transposed, transposed.getDefiningOp()); } @@ -1153,7 +1156,7 @@ class PowToMulRewritePattern : public OpRewritePattern { Value input = powOp.getX(); Value result = nullptr; - ShapedType resultType = powOp.getZ().getType().cast(); + ShapedType resultType = mlir::cast(powOp.getZ().getType()); Type elementType = getElementType(resultType); if (exponent == 0) { Attribute one = isa(elementType) diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td index 9e50b07e01..0625bb0274 100644 --- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td +++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td @@ -36,7 +36,7 @@ include "src/Dialect/ONNX/ONNX.td" // Create a DenseElementsAttr from a float attribute and an element type. def createDenseElementsAttrFromFloatAttr : NativeCodeCall< - "onnx_mlir::createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast().getElementType(), $1)">; + "onnx_mlir::createDenseElementsAttrFromFloatAttr($_builder, mlir::cast($0.getType()).getElementType(), $1)">; // Create a DenseElementsAttr from the shape of the type of a value. def createDenseElementsAttrFromShape : NativeCodeCall< @@ -60,7 +60,7 @@ def subtractOrNeg: NativeCodeCall< // Get the rank of the given value. def getRankOf : - NativeCodeCall<"$0.getType().cast().getRank()">; + NativeCodeCall<"mlir::cast($0.getType()).getRank()">; // Create an ArrayAttr of IntergerAttr(s) of [$0]. def createDenseElementsAttrOf : NativeCodeCall< @@ -68,22 +68,22 @@ def createDenseElementsAttrOf : NativeCodeCall< // Create an ArrayAttr of IntergerAttr(s) of values in [1, N-1]. def createDenseElementsAttrOfOneToRankOf : NativeCodeCall< - "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, $0.getType().cast().getRank() - 1)">; + "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast($0.getType()).getRank() - 1)">; // Create an ArrayAttr of IntergerAttr(s) of values in [1, N-2]. def createDenseElementsAttrOfOneToRankOfExclusive : NativeCodeCall< - "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, $0.getType().cast().getRank() - 2)">; + "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast($0.getType()).getRank() - 2)">; // Create an ArrayAttr of IntergerAttr(s) of values in [2, rank - 1]. def createArrayAttrOfTwoToRankOf : NativeCodeCall< - "onnx_mlir::createArrayAttrOfNToM($_builder, 2, $0.getType().cast().getRank() - 1)">; + "onnx_mlir::createArrayAttrOfNToM($_builder, 2, mlir::cast($0.getType()).getRank() - 1)">; def AttributeIsNotNull : Constraint, "Attribute is not null">; def IsDenseElementsAttr : - Constraint, - CPred<" ($_self).isa()"> + Constraint, + CPred<"mlir::isa(($_self))"> ]>, "Attribute is not a DenseElementsAttr">; // Intended to check whether there is at least one not-Null the attributes @@ -109,29 +109,29 @@ def HasNonZeroInArrayAttr: Constraint, // Check the rank of a value is greater than a given integer. class HasRankGT : - Constraint() && " - "$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() > " # rank>>; + Constraint($0.getType()) && " + "mlir::cast($0.getType()).hasRank() && " + "mlir::cast($0.getType()).getRank() > " # rank>>; // Check the rank of a value is of a given integer. class HasRankOf : - Constraint() && " - "$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # rank>>; + Constraint($0.getType()) && " + "mlir::cast($0.getType()).hasRank() && " + "mlir::cast($0.getType()).getRank() == " # rank>>; def HaveSameLastDim: Constraint< CPred<"onnx_mlir::hasShapeAndRank($0) && onnx_mlir::hasShapeAndRank($1) && " - "($0.getType().cast().getShape()" - "[$0.getType().cast().getRank() - 1] == " - "$1.getType().cast().getShape()" - "[$1.getType().cast().getRank() - 1])">, + "(mlir::cast($0.getType()).getShape()" + "[mlir::cast($0.getType()).getRank() - 1] == " + "mlir::cast($1.getType()).getShape()" + "[mlir::cast($1.getType()).getRank() - 1])">, "Two tensors have the same last dimension">; class HaveSameDim: Constraint< CPred<"onnx_mlir::hasShapeAndRank($0) && onnx_mlir::hasShapeAndRank($1) && " - "!$0.getType().cast().isDynamicDim(" # dim # ") && " - "($0.getType().cast().getShape()[" # dim # "] ==" - "$1.getType().cast().getShape()[" # dim # "])">, + "!mlir::cast($0.getType()).isDynamicDim(" # dim # ") && " + "(mlir::cast($0.getType()).getShape()[" # dim # "] ==" + "mlir::cast($1.getType()).getShape()[" # dim # "])">, "Two tensors have the same specified dimension">; def HaveSameShapedType: Constraint< @@ -153,29 +153,29 @@ def CreateNoneValue : NativeCodeCall<"$_builder.create($_loc).getRes def HasOneUse : Constraint>; -def HasNoneType : Constraint()">>; +def HasNoneType : Constraint($0.getType())">>; -def NotNoneType : Constraint())">>; +def NotNoneType : Constraint(($0.getType()))">>; def HasShapeAndRank : Constraint>; def HasSameElementType : Constraint< - CPred<"($0.getType().dyn_cast().getElementType() == " - "$1.cast<::mlir::TypeAttr>().getValue())">, + CPred<"(mlir::dyn_cast($0.getType()).getElementType() == " + "mlir::cast<::mlir::TypeAttr>($1).getValue())">, "has same element type">; def HaveSameElementType : Constraint< - CPred<"($0.getType().dyn_cast().getElementType() == " - "$1.getType().dyn_cast().getElementType())">, + CPred<"(mlir::dyn_cast($0.getType()).getElementType() == " + "mlir::dyn_cast($1.getType()).getElementType())">, "have same element types">; def HaveSameElementTypeBitWidth: Constraint< - CPred<"($0.getType().dyn_cast().getElementTypeBitWidth() == " - "$1.getType().dyn_cast().getElementTypeBitWidth())">, + CPred<"(mlir::dyn_cast($0.getType()).getElementTypeBitWidth() == " + "mlir::dyn_cast($1.getType()).getElementTypeBitWidth())">, "has same element type bitwidth">; def ElementTypeIsNotUnsigned: Constraint< - CPred<"!$_self.getType().dyn_cast().getElementType().isUnsignedInteger()">, + CPred<"!mlir::dyn_cast($_self.getType()).getElementType().isUnsignedInteger()">, "element type is not unsigned int">; def HaveSameEncodingAttr: Constraint< @@ -185,7 +185,7 @@ def HaveSameEncodingAttr: Constraint< def IsStaticShapeTensor: Constraint< CPred< - "$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">, + "mlir::cast<::mlir::ShapedType>($_self.getType()).hasStaticShape()">, "hasStaticShape">; def IsNoneValue: Constraint< @@ -221,32 +221,32 @@ def AreAllDimSizes: Constraint< def AreTheSameAxesConstant: Constraint< CPred<"onnx_mlir::AreTheSameAxesConstant(" - "(onnx_mlir::hasShapeAndRank($0) ? $0.getType().cast().getRank() : 0)," + "(onnx_mlir::hasShapeAndRank($0) ? mlir::cast($0.getType()).getRank() : 0)," "$1, $2)">, "Two values are constants with the same axis values">; def AreTheSameAxesArrayAttr: Constraint< CPred<"onnx_mlir::AreTheSameAxesArrayAttr(" - "(onnx_mlir::hasShapeAndRank($0) ? $0.getType().cast().getRank() : 0)," + "(onnx_mlir::hasShapeAndRank($0) ? mlir::cast($0.getType()).getRank() : 0)," "$1, $2)">, "Two axis arrays are the same">; class AllDimsFromAxisToEndAre: Constraint< CPred<"llvm::all_of(" - "ArrayRef($_self.getType().cast().getShape().begin() + " # axis # "," - " $_self.getType().cast().getShape().end())," + "ArrayRef(mlir::cast($_self.getType()).getShape().begin() + " # axis # "," + " mlir::cast($_self.getType()).getShape().end())," "[](int64_t val) { return (val == " # val # ");})">, "All dimensions from axis to the end are val">; def DimAtIndexIsConstant: Constraint< CPred<"onnx_mlir::hasShapeAndRank($0) &&" - "!$0.getType().cast().isDynamicDim($1.getValue().getSExtValue())">, + "!mlir::cast($0.getType()).isDynamicDim($1.getValue().getSExtValue())">, "Dim at the given index is constant" >; class RankXMinusRankYIs: Constraint< - CPred<"($0.getType().cast().getRank() " - " - $1.getType().cast().getRank() == " # diff # ")">, + CPred<"(mlir::cast($0.getType()).getRank() " + " - mlir::cast($1.getType()).getRank() == " # diff # ")">, "X' rank is greater than Y's rank diff units">; def TransposeVariadicInput: NativeCodeCall< @@ -263,7 +263,7 @@ class EqualString : Constraint>; def AxisIsTheLastDim: Constraint< CPred<"($1.getValue().getSExtValue() == -1) ||" "(onnx_mlir::hasShapeAndRank($0) &&" - " ($0.getType().cast().getRank() == $1.getValue().getSExtValue() + 1))">, + " (mlir::cast($0.getType()).getRank() == $1.getValue().getSExtValue() + 1))">, "Axis is the last dimension of the input" >; @@ -408,16 +408,16 @@ def FuseMulConvNullBiasPattern: Pat< $b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides), [(HasNoneType $b), (IsDenseElementsAttr:$denseAttr), - (IsFromONNXConstantOpWithDenseElementsAttr:$w), - (HaveSameElementType $w, $y), // multiplier and Conv weight must have the same element type. - (HasRankGT<1> $w), // rank of $w must be at least 2. + (IsFromONNXConstantOpWithDenseElementsAttr:$w), + (HaveSameElementType $w, $y), // multiplier and Conv weight must have the same element type. + (HasRankGT<1> $w), // rank of $w must be at least 2. (RankXMinusRankYIs<1> $w, $y), // rank($y) must be equal to rank($w)-1. (HaveSameDim<0> $w, $y), // the first dimension of $w and $y must be equal. (AllDimsFromAxisToEndAre<1, 1>:$y)] // all dimensions of $y must be 1 except for the first one. >; // TODO add pattern for non-null bias with contraints: -// - bias must be have rank equal to 1 and +// - bias must be have rank equal to 1 and // - bias element data type must be the same as mul constant // - bias dimension (0) must be equal to mul constant dim(0) // codegen is different too (look it up in onnx-runtime) @@ -477,8 +477,8 @@ def SwapCastSlicePattern: Pat< def IsFromONNXConstantOpWithOnesDenseElementsAttr: Constraint< And<[IsFromONNXConstantOpWithDenseElementsAttr.predicate, CPred<"::llvm::all_of(" - " onnx_mlir::getONNXConstantOp($_self).getValueAttr()" - " .dyn_cast().getValues(), " + "mlir::dyn_cast(onnx_mlir::getONNXConstantOp($_self)" + ".getValueAttr()).getValues(), " "[](int64_t repeat) { return repeat == 1;})"> ]>, "Value is not a ONNXConstantOp with a DenseElementsAttr of ones">; @@ -661,7 +661,7 @@ def RemoveIdentityReshapePattern1: Pat< [(IsNoneValue:$shape)]>; def RemoveIdentityReshapePattern2: Pat< - // Remove an identity pattern. Output and input shapes are static and the same. + // Remove an identity pattern. Output and input shapes are static and the same. (ONNXReshapeOp:$out $val, $_, $_), // Remove the reshape. (replaceWithValue $val), @@ -679,7 +679,7 @@ def SwapReshapeMatMulPattern: Pattern< // TODO: Support dynamic dimensions. (ONNXMatMulOp:$res2 (ONNXReshapeOp:$res1 $A, $_, $az), $B), [(ONNXReshapeOp (ONNXMatMulOp $A, $B, (returnType (GetReturnTypeForMatMulOpND2D $A, $B))), - (ONNXConstantOpFromDenseAttr + (ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromShape $res2) ), $az)], [(HasRankGT<2> $A), (HasRankOf<2> $res1), (HasRankOf<2> $B), // A is reshaped to 2D. diff --git a/src/Dialect/ONNX/ONNXOps/ControlFlow/If.cpp b/src/Dialect/ONNX/ONNXOps/ControlFlow/If.cpp index 1cab284591..3f0bd284d1 100644 --- a/src/Dialect/ONNX/ONNXOps/ControlFlow/If.cpp +++ b/src/Dialect/ONNX/ONNXOps/ControlFlow/If.cpp @@ -26,23 +26,24 @@ namespace { bool areCompatibleIfTypes(Type ifResultType, Type branchResultType) { // ifResultType must be tensor/seq/opt type because that's checked in // ONNXIfOp::verifyInvariantsImpl() - if (ShapedType ifShapedType = ifResultType.dyn_cast()) { - if (ShapedType branchShapedType = branchResultType.dyn_cast()) { + if (ShapedType ifShapedType = mlir::dyn_cast(ifResultType)) { + if (ShapedType branchShapedType = + mlir::dyn_cast(branchResultType)) { return ifShapedType.getElementType() == branchShapedType.getElementType(); } else { return false; } } - if (SeqType ifSeqType = ifResultType.dyn_cast()) { - if (SeqType branchSeqType = branchResultType.dyn_cast()) { + if (SeqType ifSeqType = mlir::dyn_cast(ifResultType)) { + if (SeqType branchSeqType = mlir::dyn_cast(branchResultType)) { return areCompatibleIfTypes( ifSeqType.getElementType(), branchSeqType.getElementType()); } else { return false; } } - if (OptType ifOptType = ifResultType.dyn_cast()) { - if (OptType branchOptType = branchResultType.dyn_cast()) { + if (OptType ifOptType = mlir::dyn_cast(ifResultType)) { + if (OptType branchOptType = mlir::dyn_cast(branchResultType)) { return areCompatibleIfTypes( ifOptType.getElementType(), branchOptType.getElementType()); } else { @@ -56,8 +57,8 @@ bool areCompatibleIfTypes(Type ifResultType, Type branchResultType) { // rhs) Type unionOfIfTypes(Type lhs, Type rhs) { // All asserts below are checked in areCompatibleIfTypes(). - if (ShapedType lhsShapedType = lhs.dyn_cast()) { - ShapedType rhsShapedType = rhs.cast(); + if (ShapedType lhsShapedType = mlir::dyn_cast(lhs)) { + ShapedType rhsShapedType = mlir::cast(rhs); Type elementType = lhsShapedType.getElementType(); assert(elementType == rhsShapedType.getElementType() && "tensor element types mismatch"); @@ -76,8 +77,8 @@ Type unionOfIfTypes(Type lhs, Type rhs) { return UnrankedTensorType::get(elementType); } } - if (SeqType lhsSeqType = lhs.dyn_cast()) { - SeqType rhsSeqType = rhs.cast(); + if (SeqType lhsSeqType = mlir::dyn_cast(lhs)) { + SeqType rhsSeqType = mlir::cast(rhs); int64_t length = lhsSeqType.getLength() == rhsSeqType.getLength() ? lhsSeqType.getLength() : -1; @@ -85,8 +86,8 @@ Type unionOfIfTypes(Type lhs, Type rhs) { rhsSeqType.getElementType()), length); } - if (OptType lhsOptType = lhs.dyn_cast()) { - OptType rhsOptType = rhs.cast(); + if (OptType lhsOptType = mlir::dyn_cast(lhs)) { + OptType rhsOptType = mlir::cast(rhs); return OptType::get(unionOfIfTypes( lhsOptType.getElementType(), rhsOptType.getElementType())); } diff --git a/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp b/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp index c17bf4e912..bcb229e6d0 100644 --- a/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp +++ b/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp @@ -99,7 +99,7 @@ LogicalResult ONNXLoopOp::inferShapes( auto bodyScanOutputTys = llvm::drop_begin(bodyOuputTys, numCarried); for (auto [opScanOutput, ty] : llvm::zip(scan_outputs(), bodyScanOutputTys)) { // TODO: Handle SeqType, OptType. - if (auto rankedTy = ty.dyn_cast()) { + if (auto rankedTy = mlir::dyn_cast(ty)) { SmallVector unsqueezedShape(rankedTy.getShape()); // Note that we may know the extent of the scan output leading // dimension, which is very likely just the trip count specified as an diff --git a/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp b/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp index 8daa1f4a2b..5c62a5fa95 100644 --- a/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp +++ b/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp @@ -68,7 +68,8 @@ LogicalResult ONNXScanOp::inferShapes( assert(!getScanOutputAxes() && "scan_output_axes are unsupported"); assert(!scan_inputs().empty() && "there must be 1 or more scan inputs"); - auto firstScanInputType = scan_inputs().front().getType().cast(); + auto firstScanInputType = + mlir::cast(scan_inputs().front().getType()); // Number of body iterations is the dim size of the scan input sequence axis, // which is also the dim size of the scan outputs concat axis. int64_t sequence_length = firstScanInputType.hasRank() @@ -93,7 +94,8 @@ LogicalResult ONNXScanOp::inferShapes( auto bodyScanInputs = llvm::drop_begin(bodyInputs, numStateVariables); for (auto [opScanInput, bodyScanInput] : llvm::zip(scan_inputs(), bodyScanInputs)) { - if (auto rankedTy = opScanInput.getType().dyn_cast()) { + if (auto rankedTy = + mlir::dyn_cast(opScanInput.getType())) { ArrayRef squeezedShape(rankedTy.getShape().drop_front(1)); updateType(getOperation(), bodyScanInput, squeezedShape, rankedTy.getElementType(), /*encoding=*/nullptr, @@ -121,7 +123,7 @@ LogicalResult ONNXScanOp::inferShapes( // with an extra leading dimension. auto bodyScanOutputTys = llvm::drop_begin(bodyOuputTys, numStateVariables); for (auto [opScanOutput, ty] : llvm::zip(scan_outputs(), bodyScanOutputTys)) { - if (auto rankedTy = ty.dyn_cast()) { + if (auto rankedTy = mlir::dyn_cast(ty)) { SmallVector unsqueezedShape(rankedTy.getShape()); unsqueezedShape.insert(unsqueezedShape.begin(), sequence_length); updateType(getOperation(), opScanOutput, unsqueezedShape, diff --git a/src/Dialect/ONNX/ONNXOps/ML/CategoryMapper.cpp b/src/Dialect/ONNX/ONNXOps/ML/CategoryMapper.cpp index 1148ef8c29..dd13df6b26 100644 --- a/src/Dialect/ONNX/ONNXOps/ML/CategoryMapper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ML/CategoryMapper.cpp @@ -46,9 +46,9 @@ LogicalResult ONNXCategoryMapperOp::verify() { return success(); } - ShapedType inputType = X.getType().cast(); + ShapedType inputType = mlir::cast(X.getType()); Type elementType = inputType.getElementType(); - if (!elementType.isInteger(64) && !elementType.isa()) + if (!elementType.isInteger(64) && !mlir::isa(elementType)) return emitOpError("input must be a tensor of int64 or string"); // Check attributes. @@ -61,7 +61,7 @@ LogicalResult ONNXCategoryMapperOp::verify() { if (elementType.isInteger(64) && !getDefaultStringAttr()) return emitOpError("'default_string' attribute is missing."); - if (elementType.isa() && !getDefaultInt64Attr()) + if (mlir::isa(elementType) && !getDefaultInt64Attr()) return emitOpError("'default_int64' attribute is missing."); return success(); @@ -77,9 +77,10 @@ LogicalResult ONNXCategoryMapperOp::inferShapes( if (!hasShapeAndRank(getX())) return success(); - Type inputElementType = getX().getType().cast().getElementType(); + Type inputElementType = + mlir::cast(getX().getType()).getElementType(); assert((inputElementType.isInteger(64) || - inputElementType.isa()) && + mlir::isa(inputElementType)) && "Input tensor must have int64 or string element type."); Type outputElementType; diff --git a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp index 710aa9bf35..3b3805ea26 100644 --- a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp +++ b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp @@ -29,7 +29,7 @@ LogicalResult ONNXOneHotEncoderOpShapeHelper::computeShape() { ONNXOneHotEncoderOp oneHotOp = llvm::cast(op); ONNXOneHotEncoderOpAdaptor operandAdaptor(operands); Value X = operandAdaptor.getX(); - ShapedType inputType = X.getType().dyn_cast(); + ShapedType inputType = mlir::dyn_cast(X.getType()); assert(inputType && "expected ranked type"); // If the input is a tensor of float, int32, or double, @@ -65,7 +65,7 @@ LogicalResult ONNXOneHotEncoderOp::verify() { if (!hasShapeAndRank(input)) return success(); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); if (!inputType) return success(); diff --git a/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp b/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp index 234414741e..be4b5a6694 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp @@ -38,7 +38,7 @@ LogicalResult ONNXBernoulliOp::inferShapes( (onnx::TensorProto_DataType)getDtypeAttr().getValue().getSExtValue()); } else { elementType = - getInput().getType().cast().getElementType(); + mlir::cast(getInput().getType()).getElementType(); } ONNXBernoulliOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); diff --git a/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp b/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp index 2912b0a57d..82047008a2 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp @@ -97,7 +97,8 @@ LogicalResult ONNXDFTOp::inferShapes( if (!isNoneValue(getAxis()) && !hasShapeAndRank(getAxis())) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXDFTOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/Einsum.cpp b/src/Dialect/ONNX/ONNXOps/Math/Einsum.cpp index f2148b1383..030c84acdc 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Einsum.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Einsum.cpp @@ -67,9 +67,9 @@ LogicalResult ONNXEinsumOp::verify() { } Type firstElementType = - inputs[0].getType().cast().getElementType(); + mlir::cast(inputs[0].getType()).getElementType(); for (Value input : inputs) { - ShapedType type = input.getType().cast(); + ShapedType type = mlir::cast(input.getType()); if (type.getElementType() != firstElementType) { return emitOpError() << "different input element types"; } @@ -90,7 +90,7 @@ LogicalResult ONNXEinsumOp::inferShapes( return success(); // Can only infer once operand shapes are known. Type elementType = - getOperand(0).getType().cast().getElementType(); + mlir::cast(getOperand(0).getType()).getElementType(); ONNXEinsumOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp b/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp index dbdcb96464..af472ec6ba 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp @@ -223,7 +223,7 @@ FailureOr inferSignature( for (size_t i = 0; i < inputs.size(); ++i) { Value input = inputs[i]; StringRef equationInput = equationInputs[i]; - ShapedType type = input.getType().cast(); + ShapedType type = mlir::cast(input.getType()); auto shape = type.getShape(); size_t rank = shape.size(); size_t letters = countLetters(equationInput); diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp index 0c3092d66d..f3961667e5 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp @@ -76,7 +76,7 @@ static LogicalResult inferShapeForBroadcastingOps( if (!elementType) elementType = - op.getOperand(0).getType().template cast().getElementType(); + mlir::cast(op.getOperand(0).getType()).getElementType(); ONNXBroadcastOpShapeHelper shapeHelper(op.getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -290,18 +290,18 @@ LogicalResult ONNXMinOp::inferShapes( LogicalResult ONNXModOp::verify() { Type elementType; - if (getA().getType().isa()) - elementType = getA().getType().cast().getElementType(); + if (mlir::isa(getA().getType())) + elementType = mlir::cast(getA().getType()).getElementType(); else return emitOpError("Input type must be TensorType or MemRefType"); // Verify that when the input type is floating point, then `fmod` attribute // must be set to 1. - if (elementType.isa() && (getFmod() != 1)) + if (mlir::isa(elementType) && (getFmod() != 1)) return emitOpError("fmod must be 1 when the input type is floating point"); // Verify that when the input type is integer, then `fmod` attribute // must be set to 0. - if (elementType.isa() && (getFmod() != 0)) + if (mlir::isa(elementType) && (getFmod() != 0)) return emitOpError("fmod must be 0 when the input type is an integer"); return success(); @@ -343,13 +343,13 @@ LogicalResult ONNXOrOp::inferShapes( //===----------------------------------------------------------------------===// LogicalResult ONNXPowOp::verify() { - ShapedType lhsTy = getX().getType().cast(); - ShapedType rhsTy = getY().getType().cast(); + ShapedType lhsTy = mlir::cast(getX().getType()); + ShapedType rhsTy = mlir::cast(getY().getType()); Type rhsETy = rhsTy.getElementType(); Type lhsETy = lhsTy.getElementType(); if (rhsETy != lhsETy) return emitOpError("Pow with different input type not implemented yet"); - if (lhsETy.isa() || lhsETy.isa()) + if (mlir::isa(lhsETy) || mlir::isa(lhsETy)) return emitOpError("Integer power not implemented yet"); return success(); } @@ -369,9 +369,10 @@ LogicalResult ONNXPReluOp::verify() { if (!hasShapeAndRank(getSlope())) return success(); - ArrayRef xShape = getX().getType().cast().getShape(); + ArrayRef xShape = + mlir::cast(getX().getType()).getShape(); ArrayRef slopeShape = - getSlope().getType().cast().getShape(); + mlir::cast(getSlope().getType()).getShape(); // PRelu supports unidirectional broadcasting, that is slope should be // unidirectional broadcast to input X. if (slopeShape.size() > xShape.size()) @@ -384,7 +385,7 @@ LogicalResult ONNXPReluOp::inferShapes( if (!hasShapeAndRank(getOperation())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXPReluOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -425,7 +426,8 @@ LogicalResult ONNXWhereOp::verify() { LogicalResult ONNXWhereOp::inferShapes( std::function doShapeInference) { - Type resultElementType = getX().getType().cast().getElementType(); + Type resultElementType = + mlir::cast(getX().getType()).getElementType(); return inferShapeForBroadcastingOps(*this, resultElementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp index a19d9cbfc6..a38ddfcb11 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp @@ -34,7 +34,8 @@ LogicalResult inferShapeForUnaryOps(Operation *op) { Value input = op->getOperand(0); if (!hasShapeAndRank(input)) return success(); - RankedTensorType inputType = input.getType().dyn_cast(); + RankedTensorType inputType = + mlir::dyn_cast(input.getType()); return inferShapeForUnaryOps( op, inputType.getElementType(), inputType.getEncoding()); } @@ -45,7 +46,8 @@ LogicalResult inferShapeForUnaryOps(Operation *op, Type elementType) { Value input = op->getOperand(0); if (!hasShapeAndRank(input)) return success(); - RankedTensorType inputType = input.getType().dyn_cast(); + RankedTensorType inputType = + mlir::dyn_cast(input.getType()); return inferShapeForUnaryOps(op, elementType, inputType.getEncoding()); } @@ -148,7 +150,7 @@ LogicalResult ONNXCastOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - Type elementType = (*this)->getAttr("to").cast<::TypeAttr>().getValue(); + Type elementType = mlir::cast<::TypeAttr>((*this)->getAttr("to")).getValue(); ONNXCastOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -262,7 +264,7 @@ LogicalResult ONNXGeluOp::verify() { LogicalResult ONNXGeluOp::inferShapes( std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation(), - this->getResult().getType().cast().getElementType()); + mlir::cast(this->getResult().getType()).getElementType()); } //===----------------------------------------------------------------------===// @@ -317,7 +319,7 @@ LogicalResult ONNXIsInfOp::verify() { LogicalResult ONNXIsInfOp::inferShapes( std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation(), - this->getResult().getType().cast().getElementType()); + mlir::cast(this->getResult().getType()).getElementType()); } //===----------------------------------------------------------------------===// @@ -358,7 +360,7 @@ LogicalResult ONNXLogSoftmaxOp::verify() { return success(); // Won't be able to do any checking at this stage. int64_t inputRank = - operandAdaptor.getInput().getType().cast().getRank(); + mlir::cast(operandAdaptor.getInput().getType()).getRank(); int64_t axisIndex = getAxis(); // axis attribute must be in the range [-r,r-1], where r = rank(input). @@ -448,7 +450,7 @@ LogicalResult ONNXScalerOp::inferShapes( return success(); ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {}); - RankedTensorType xType = getX().getType().dyn_cast(); + RankedTensorType xType = mlir::dyn_cast(getX().getType()); return shapeHelper.computeShapeAndUpdateType( FloatType::getF32(getContext()), xType.getEncoding()); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp b/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp index eb75c090eb..ab26c52771 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp @@ -38,13 +38,13 @@ LogicalResult ONNXGemmOpShapeHelper::computeShape() { hasBias = !isNoneValue(C); // Test ranks. - if (A.getType().cast().getShape().size() != 2) + if (mlir::cast(A.getType()).getShape().size() != 2) return op->emitError("Gemm with A should be a 2D tensor"); - if (B.getType().cast().getShape().size() != 2) + if (mlir::cast(B.getType()).getShape().size() != 2) return op->emitError("Gemm with B should be a 2D tensor"); cRank = 0; if (hasBias) { - cRank = C.getType().cast().getShape().size(); + cRank = mlir::cast(C.getType()).getShape().size(); if (cRank > 2) return op->emitError("Gemm with C should be a 1D or 2D tensor"); } @@ -123,7 +123,7 @@ LogicalResult ONNXGemmOp::inferShapes( (hasBias && !hasShapeAndRank(getC()))) return success(); - Type elementType = getA().getType().cast().getElementType(); + Type elementType = mlir::cast(getA().getType()).getElementType(); ONNXGemmOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/HardMax.cpp b/src/Dialect/ONNX/ONNXOps/Math/HardMax.cpp index 9531c2aed0..e2219fc4a4 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/HardMax.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/HardMax.cpp @@ -30,7 +30,7 @@ LogicalResult ONNXHardmaxOp::verify() { // axis attribute must be in the range [-r,r-1], where r = rank(input). int64_t axisValue = getAxis(); - int64_t inputRank = input.getType().cast().getRank(); + int64_t inputRank = mlir::cast(input.getType()).getRank(); if (axisValue < -inputRank || axisValue >= inputRank) return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError( *this->getOperation(), "axis", axisValue, @@ -48,7 +48,7 @@ LogicalResult ONNXHardmaxOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); int64_t inputRank = inputType.getRank(); int64_t axisValue = getAxis(); diff --git a/src/Dialect/ONNX/ONNXOps/Math/LRN.cpp b/src/Dialect/ONNX/ONNXOps/Math/LRN.cpp index 10fa47d01a..41170a3c6e 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/LRN.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/LRN.cpp @@ -43,7 +43,7 @@ LogicalResult ONNXLRNOpShapeHelper::computeShape() { LogicalResult ONNXLRNOp::inferShapes( std::function doShapeInference) { - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXLRNOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp index 9ffc7c9a2c..c17b62679d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp @@ -169,7 +169,7 @@ LogicalResult ONNXMatMulOp::inferShapes( if (!hasShapeAndRank(getA()) || !hasShapeAndRank(getB())) return success(); - Type elementType = getA().getType().cast().getElementType(); + Type elementType = mlir::cast(getA().getType()).getElementType(); ONNXMatMulOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -184,7 +184,8 @@ LogicalResult ONNXMatMulIntegerOp::inferShapes( if (!hasShapeAndRank(getA()) || !hasShapeAndRank(getB())) return success(); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); ONNXMatMulIntegerOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -200,8 +201,8 @@ LogicalResult ONNXMatMulIntegerOp::verify() { Value A = operandAdaptor.getA(); Value aZeroPoint = this->getAZeroPoint(); if (!isNoneValue(aZeroPoint)) { - auto aType = A.getType().cast(); - auto aZeroPointType = aZeroPoint.getType().cast(); + auto aType = mlir::cast(A.getType()); + auto aZeroPointType = mlir::cast(aZeroPoint.getType()); uint64_t aRank = aType.getRank(); uint64_t aZeroPointRank = aZeroPointType.getRank(); ArrayRef aShape = aType.getShape(); @@ -292,7 +293,8 @@ LogicalResult ONNXQLinearMatMulOp::inferShapes( if (!hasShapeAndRank(getA()) || !hasShapeAndRank(getB())) return success(); - Type elementType = getResult().getType().cast().getElementType(); + Type elementType = + mlir::cast(getResult().getType()).getElementType(); ONNXQLinearMatMulOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp index 308bf3a002..9df2bbe18b 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp @@ -31,8 +31,10 @@ LogicalResult ONNXRandomNormalLikeOp::verify() { if (!hasShapeAndRank(output)) return success(); - auto inputType = input.getType().cast().getElementType(); - auto outputType = output.getType().cast().getElementType(); + auto inputType = + mlir::cast(input.getType()).getElementType(); + auto outputType = + mlir::cast(output.getType()).getElementType(); auto elementTypeIDDType = operandAdaptor.getDtype(); if (elementTypeIDDType) { @@ -63,7 +65,7 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes( std::function doShapeInference) { if (!hasShapeAndRank(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); auto elementTypeIDDType = getDtype(); // Default output tensor type in all cases is the input tensor type. diff --git a/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp b/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp index 7bcc2e273f..d988f67bc1 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp @@ -130,7 +130,7 @@ LogicalResult ONNXGenericReductionOpShapeHelper::computeShape() { // there, from input putting question mark in there. Not sure if // successful, if it is, it should be generalized to all ops. OP_TYPE reduceOp = llvm::cast(op); - if (reduceOp.getResult().getType().template isa()) { + if (mlir::isa(reduceOp.getResult().getType())) { // Have already some shapes, keep them in ShapeHelper DimsExpr outputDims; createIE->getShapeAsDims(reduceOp.getResult(), outputDims); @@ -163,7 +163,7 @@ static LogicalResult inferShapeForReductionOps_old(OP_TYPE &op) { return success(); ShapedType dataType = - operandAdaptor.getData().getType().template cast(); + mlir::cast(operandAdaptor.getData().getType()); ONNXGenericReductionOpShapeHelper shapeHelper(op.getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(dataType.getElementType()); } @@ -181,7 +181,7 @@ static LogicalResult inferShapeForReductionOps(OP_TYPE &op) { return success(); ShapedType dataType = - operandAdaptor.getData().getType().template cast(); + mlir::cast(operandAdaptor.getData().getType()); ONNXGenericReductionOpShapeHelper shapeHelper(op.getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(dataType.getElementType()); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp index 23304b2e3b..189d855805 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp @@ -40,9 +40,9 @@ LogicalResult ONNXScatterElementsOp::verify() { Value data = operandAdaptor.getData(); Value indices = operandAdaptor.getIndices(); Value updates = operandAdaptor.getUpdates(); - auto dataType = data.getType().cast(); - auto indicesType = indices.getType().cast(); - auto updatesType = updates.getType().cast(); + auto dataType = mlir::cast(data.getType()); + auto indicesType = mlir::cast(indices.getType()); + auto updatesType = mlir::cast(updates.getType()); int64_t dataRank = dataType.getRank(); int64_t indicesRank = indicesType.getRank(); int64_t updatesRank = updatesType.getRank(); @@ -110,9 +110,9 @@ LogicalResult ONNXScatterNDOp::verify() { Value data = operandAdaptor.getData(); Value indices = operandAdaptor.getIndices(); Value updates = operandAdaptor.getUpdates(); - auto dataType = data.getType().cast(); - auto indicesType = indices.getType().cast(); - auto updatesType = updates.getType().cast(); + auto dataType = mlir::cast(data.getType()); + auto indicesType = mlir::cast(indices.getType()); + auto updatesType = mlir::cast(updates.getType()); int64_t dataRank = dataType.getRank(); int64_t indicesRank = indicesType.getRank(); int64_t updatesRank = updatesType.getRank(); diff --git a/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp b/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp index a71a1aa44f..641faa1e4d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp @@ -74,7 +74,7 @@ LogicalResult ONNXTopKOp::verify() { Value K = operandAdaptor.getK(); if (hasShapeAndRank(K)) { // K's rank must be zero or one. - int64_t KRank = K.getType().cast().getRank(); + int64_t KRank = mlir::cast(K.getType()).getRank(); if (KRank > 1) return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError( *this->getOperation(), K, KRank, "< 2"); @@ -83,7 +83,7 @@ LogicalResult ONNXTopKOp::verify() { // axis attribute must be in the range [-r,r-1], where r = rank(X). Value X = operandAdaptor.getX(); if (hasShapeAndRank(X)) { - int64_t Xrank = X.getType().cast().getRank(); + int64_t Xrank = mlir::cast(X.getType()).getRank(); int64_t axis = this->getAxis(); if (axis < -Xrank || axis >= Xrank) @@ -106,7 +106,7 @@ LogicalResult ONNXTopKOp::inferShapes( return success(); Builder b(getContext()); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXTopKOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateTypes({elementType, b.getI64Type()}); } diff --git a/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp b/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp index 7ebe3fe1ce..951905f8ad 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp @@ -229,8 +229,9 @@ LogicalResult processConvStrideParam( template LogicalResult processConvTypeParams(T *op, Value inputOperand, Value W) { // 1) Get shape of input. Shape is not guaranteed to be compile time constant. - auto inputShape = inputOperand.getType().cast().getShape(); - auto wShape = W.getType().cast().getShape(); + auto inputShape = + mlir::cast(inputOperand.getType()).getShape(); + auto wShape = mlir::cast(W.getType()).getShape(); // If kernel_shape isn't provided, add kernel_shape to the the op based on the // shape of the input and weights. @@ -507,7 +508,7 @@ LogicalResult ONNXConvOp::verify() { // Won't be able to do any checking at this stage. return success(); } - auto wShape = W.getType().cast().getShape(); + auto wShape = mlir::cast(W.getType()).getShape(); int64_t spatialRank = wShape.size() - 2; // If ranked, verify ranks of inputs. if (spatialRank < 1) @@ -530,7 +531,7 @@ LogicalResult ONNXConvOp::verify() { "Channel Out (M) must be a multiple of the number of groups"); } if (hasShapeAndRank(X)) { - auto xShape = X.getType().cast().getShape(); + auto xShape = mlir::cast(X.getType()).getShape(); if ((int64_t)xShape.size() - 2 != spatialRank) return emitOpError("Input and filter rank mismatch"); if (xShape[1] != ShapedType::kDynamic && xShape[1] % g != 0) @@ -543,7 +544,7 @@ LogicalResult ONNXConvOp::verify() { } } if (hasBias && hasShapeAndRank(B)) { - auto bShape = B.getType().cast().getShape(); + auto bShape = mlir::cast(B.getType()).getShape(); if (bShape.size() != 1) return emitOpError("Bias should have a rank of one"); if (bShape[0] != ShapedType::kDynamic && @@ -601,7 +602,7 @@ LogicalResult ONNXConvTransposeOp::verify() { auto X = operandAdaptor.getX(); auto W = operandAdaptor.getW(); auto B = operandAdaptor.getB(); - bool hasBias = !B.getType().isa(); + bool hasBias = !mlir::isa(B.getType()); int64_t g = getGroup(); if (g < 1) return emitOpError("group must be strictly positive"); @@ -610,14 +611,14 @@ LogicalResult ONNXConvTransposeOp::verify() { // Won't be able to do any checking at this stage. return success(); } - auto wShape = W.getType().cast().getShape(); + auto wShape = mlir::cast(W.getType()).getShape(); int64_t spatialRank = wShape.size() - 2; // If ranked, verify ranks of inputs. if (spatialRank < 1) return emitOpError("Spatial rank must be strictly positive"); if (hasShapeAndRank(X)) { - auto xShape = X.getType().cast().getShape(); + auto xShape = mlir::cast(X.getType()).getShape(); if ((int64_t)xShape.size() - 2 != spatialRank) return emitOpError("Input and filter rank mismatch"); if (xShape[1] != ShapedType::kDynamic && @@ -627,7 +628,7 @@ LogicalResult ONNXConvTransposeOp::verify() { } } if (hasBias && hasShapeAndRank(B)) { - auto bShape = B.getType().cast().getShape(); + auto bShape = mlir::cast(B.getType()).getShape(); if (bShape.size() != 1) return emitOpError("Bias should have a rank of one"); if (bShape[0] != ShapedType::kDynamic && @@ -713,14 +714,14 @@ LogicalResult ONNXQLinearConvOp::inferShapes( bool hasBias = !isNoneValue(B()); // Cannot infer shape if no shape exists. - if (!getX().getType().isa() || - !getW().getType().isa() || - (hasBias && !getB().getType().isa())) + if (!mlir::isa(getX().getType()) || + !mlir::isa(getW().getType()) || + (hasBias && !mlir::isa(getB().getType()))) return success(); - auto xTy = getX().getType().cast(); + auto xTy = mlir::cast(getX().getType()); auto xShape = xTy.getShape(); - auto weightTy = getW().getType().cast(); + auto weightTy = mlir::cast(getW().getType()); auto weightShape = weightTy.getShape(); auto builder = Builder(this->getContext()); @@ -747,7 +748,7 @@ LogicalResult ONNXQLinearConvOp::inferShapes( // Check the size of bias. if (hasBias) { - auto bTx = getB().getType().cast(); + auto bTx = mlir::cast(getB().getType()); auto bShape = bTx.getShape(); if (bShape.size() != 1) return emitError("bias should be one dimensional"); diff --git a/src/Dialect/ONNX/ONNXOps/NN/Dropout.cpp b/src/Dialect/ONNX/ONNXOps/NN/Dropout.cpp index f12ebfcd83..7885466ff3 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Dropout.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Dropout.cpp @@ -56,7 +56,7 @@ LogicalResult ONNXDropoutOp::inferShapes( return success(); Type outputElementType = - getData().getType().cast().getElementType(); + mlir::cast(getData().getType()).getElementType(); IntegerType maskElementType = IntegerType::get(getContext(), 1, IntegerType::Signless); ONNXDropoutOpShapeHelper shapeHelper(getOperation(), {}); diff --git a/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc b/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc index 0d23814f63..d742a01bbd 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc +++ b/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc @@ -181,7 +181,7 @@ static LogicalResult verifyKernelShape(T *op, Value filterOperand, // 1) Get shape of filter. Shape is not guaranteed to be compile time // constant. ArrayRef filterShape = - filterOperand ? filterOperand.getType().cast().getShape() + filterOperand ? mlir::cast(filterOperand.getType()).getShape() : ArrayRef(); // 2) Get kernel_shape attribute if (!kernelShapeOpt.has_value()) { diff --git a/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp b/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp index f7f6b31ae8..df7f3c2d56 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp @@ -43,11 +43,11 @@ LogicalResult ONNXBatchNormalizationInferenceModeOp::inferShapes( return success(); // Verifier code. - auto inputTensorTy = getX().getType().cast(); - auto scaleTensorTy = getScale().getType().cast(); - auto biasTensorTy = getB().getType().cast(); - auto meanTensorTy = getMean().getType().cast(); - auto varianceTensorTy = getVar().getType().cast(); + auto inputTensorTy = mlir::cast(getX().getType()); + auto scaleTensorTy = mlir::cast(getScale().getType()); + auto biasTensorTy = mlir::cast(getB().getType()); + auto meanTensorTy = mlir::cast(getMean().getType()); + auto varianceTensorTy = mlir::cast(getVar().getType()); // Check whether the shapes of scale, bias, mean and variance are valid. // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. @@ -79,7 +79,8 @@ LogicalResult ONNXBatchNormalizationInferenceModeOp::inferShapes( } // The output tensor of the same shape as the input. - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXBatchNormalizationInferenceModeOpShapeHelper shapeHelper( getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); @@ -107,7 +108,7 @@ LogicalResult ONNXInstanceNormalizationOp::verify() { // Won't be able to do any checking at this stage. return success(); } - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto inputShape = inputType.getShape(); auto inputElementType = inputType.getElementType(); int64_t spatialRank = inputShape.size() - 2; @@ -118,7 +119,7 @@ LogicalResult ONNXInstanceNormalizationOp::verify() { // Check bias B. if (hasShapeAndRank(B)) { // Can check at this stage. - auto bType = B.getType().cast(); + auto bType = mlir::cast(B.getType()); auto bShape = bType.getShape(); if (bShape.size() != 1) return emitOpError("Bias should have a rank of one"); @@ -133,7 +134,7 @@ LogicalResult ONNXInstanceNormalizationOp::verify() { // Check scale. if (hasShapeAndRank(scale)) { // Can check at this stage. - auto scaleType = scale.getType().cast(); + auto scaleType = mlir::cast(scale.getType()); auto scaleShape = scaleType.getShape(); if (scaleShape.size() != 1) return emitOpError("Scale should have a rank of one"); @@ -171,7 +172,7 @@ LogicalResult verifyShapeForLayerNorm(OP_TYPE *op) { // Won't be able to do any checking at this stage. return success(); } - ShapedType XType = X.getType().cast(); + ShapedType XType = mlir::cast(X.getType()); ArrayRef XShape = XType.getShape(); int64_t XRank = XShape.size(); Type XElementType = XType.getElementType(); @@ -184,7 +185,7 @@ LogicalResult verifyShapeForLayerNorm(OP_TYPE *op) { // Check bias B. if (hasShapeAndRank(B)) { // Can check at this stage. - ShapedType bType = B.getType().cast(); + ShapedType bType = mlir::cast(B.getType()); ArrayRef bShape = bType.getShape(); SmallVector BBroadcastShape; if (!OpTrait::util::getBroadcastedShape(XShape, bShape, BBroadcastShape)) @@ -200,7 +201,7 @@ LogicalResult verifyShapeForLayerNorm(OP_TYPE *op) { // Check scale. if (hasShapeAndRank(scale)) { // Can check at this stage. - ShapedType scaleType = scale.getType().cast(); + ShapedType scaleType = mlir::cast(scale.getType()); ArrayRef scaleShape = scaleType.getShape(); SmallVector scaleBroadcastShape; if (!OpTrait::util::getBroadcastedShape( @@ -226,7 +227,7 @@ mlir::LogicalResult ONNXLNOpShapeHelper::computeShape() { // Get rank and axis attribute. Value X = operandAdaptor.getX(); - int64_t XRank = X.getType().cast().getRank(); + int64_t XRank = mlir::cast(X.getType()).getRank(); int64_t axis = getAxisInRange(lnOp.getAxis(), XRank); // Check optional outputs, with specialization for ONNXLayerNormalizationOp @@ -289,7 +290,8 @@ LogicalResult ONNXLayerNormalizationOp::inferShapes( if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getScale()) || (!isNoneValue(getB()) && !hasShapeAndRank(getB()))) return success(); - Type commonType = getX().getType().cast().getElementType(); + Type commonType = + mlir::cast(getX().getType()).getElementType(); ONNXLayerNormalizationOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(commonType); } @@ -309,7 +311,8 @@ LogicalResult ONNXRMSLayerNormalizationOp::inferShapes( if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getScale()) || (!isNoneValue(getB()) && !hasShapeAndRank(getB()))) return success(); - Type commonType = getX().getType().cast().getElementType(); + Type commonType = + mlir::cast(getX().getType()).getElementType(); ONNXRMSLayerNormalizationOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(commonType); } diff --git a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp index 95fca161ee..27cce02696 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp @@ -106,7 +106,7 @@ LogicalResult ONNXAveragePoolOp::verify() { // Get operands. auto X = operandAdaptor.getX(); if (hasShapeAndRank(X)) { - auto xShape = X.getType().cast().getShape(); + auto xShape = mlir::cast(X.getType()).getShape(); if ((int64_t)xShape.size() - 2 != spatialRank) return emitOpError("Input and kernel shape rank mismatch"); } @@ -130,7 +130,7 @@ LogicalResult ONNXAveragePoolOp::inferShapes( if (!hasShapeAndRank(getX())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXAveragePoolOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -144,7 +144,7 @@ LogicalResult ONNXGlobalAveragePoolOp::inferShapes( if (!hasShapeAndRank(getX())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXGlobalAveragePoolOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -158,7 +158,7 @@ LogicalResult ONNXGlobalLpPoolOp::inferShapes( if (!hasShapeAndRank(getX())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXGlobalLpPoolOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -172,7 +172,7 @@ LogicalResult ONNXGlobalMaxPoolOp::inferShapes( if (!hasShapeAndRank(getX())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXGlobalMaxPoolOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -214,7 +214,7 @@ LogicalResult ONNXMaxPoolSingleOutOp::verify() { // Get operands. auto X = operandAdaptor.getX(); if (hasShapeAndRank(X)) { - auto xShape = X.getType().cast().getShape(); + auto xShape = mlir::cast(X.getType()).getShape(); if (static_cast(xShape.size()) - 2 != spatialRank) return emitOpError("Input and kernel shape rank mismatch"); } @@ -242,7 +242,7 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes( auto kernelShape = getKernelShape(); assert(kernelShape && "verified that we had kernel shape"); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); IndexExprBuilderForAnalysis createIE(getLoc()); ONNXMaxPoolSingleOutOpShapeHelper shapeHelper(getOperation(), {}, &createIE); return shapeHelper.computeShapeAndUpdateType(elementType); @@ -257,7 +257,8 @@ LogicalResult ONNXMaxRoiPoolOp::inferShapes( if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getRois())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXMaxRoiPoolOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp b/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp index f516e26024..1e833aac54 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp @@ -62,9 +62,9 @@ LogicalResult ONNXRoiAlignOp::verify() { if (!hasShapeAndRank(X) || !hasShapeAndRank(batch_indices)) return success(); - int64_t x_rank = X.getType().cast().getRank(); + int64_t x_rank = mlir::cast(X.getType()).getRank(); int64_t batch_indices_rank = - batch_indices.getType().cast().getRank(); + mlir::cast(batch_indices.getType()).getRank(); // Test ranks. if (x_rank != 4) @@ -85,7 +85,7 @@ LogicalResult ONNXRoiAlignOp::inferShapes( if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getBatchIndices())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); ONNXRoiAlignOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/ObjectDetection/NonMaxSuppression.cpp b/src/Dialect/ONNX/ONNXOps/ObjectDetection/NonMaxSuppression.cpp index 944c05a4df..90b11eb9da 100644 --- a/src/Dialect/ONNX/ONNXOps/ObjectDetection/NonMaxSuppression.cpp +++ b/src/Dialect/ONNX/ONNXOps/ObjectDetection/NonMaxSuppression.cpp @@ -52,7 +52,7 @@ LogicalResult ONNXNonMaxSuppressionOp::verify() { // Check operands. if (hasShapeAndRank(boxes)) { - auto shape = boxes.getType().cast().getShape(); + auto shape = mlir::cast(boxes.getType()).getShape(); if (shape.size() != 3) return emitOpError("boxes should have a rank of three"); if (!ShapedType::isDynamic(shape[2]) && shape[2] != 4) @@ -60,20 +60,20 @@ LogicalResult ONNXNonMaxSuppressionOp::verify() { } if (hasShapeAndRank(scores)) - if (scores.getType().cast().getRank() != 3) + if (mlir::cast(scores.getType()).getRank() != 3) return emitOpError("scores should have a rank of three"); if (hasShapeAndRank(MOPC)) - if (MOPC.getType().cast().getRank() > 1) + if (mlir::cast(MOPC.getType()).getRank() > 1) return emitOpError( "max_output_boxex_per_class should have a rank of zero or one"); if (hasShapeAndRank(scoreThreshold)) - if (scoreThreshold.getType().cast().getRank() > 1) + if (mlir::cast(scoreThreshold.getType()).getRank() > 1) return emitOpError("score_threshold should have a rank of zero or one"); if (hasShapeAndRank(iouThreshold)) - if (iouThreshold.getType().cast().getRank() > 1) + if (mlir::cast(iouThreshold.getType()).getRank() > 1) return emitOpError("iou_threshold should have a rank of zero or one"); return success(); diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 6f0356fa18..3c12168e8f 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -77,15 +77,15 @@ StringRef convertONNXTensorDataLayoutToString( } bool isONNXTensor(const Type type) { - if (auto ttp = type.dyn_cast()) - if (ttp.getEncoding().dyn_cast_or_null()) + if (auto ttp = mlir::dyn_cast(type)) + if (mlir::dyn_cast_or_null(ttp.getEncoding())) return true; return false; } ONNXTensorEncodingAttr getONNXTensorEncoding(Type type) { - if (auto ttp = type.dyn_cast()) - return ttp.getEncoding().dyn_cast_or_null(); + if (auto ttp = mlir::dyn_cast(type)) + return mlir::dyn_cast_or_null(ttp.getEncoding()); return nullptr; } @@ -293,22 +293,22 @@ size_t ArrayAttrSize(ArrayAttr a) { return a.size(); } size_t ArrayAttrSize(std::optional a) { return a.value().size(); } int64_t ArrayAttrIntVal(ArrayAttr a, int i) { - return (a.getValue()[i]).cast().getInt(); + return mlir::cast(a.getValue()[i]).getInt(); } int64_t ArrayAttrIntVal(std::optional a, int i) { - return (a.value().getValue()[i]).cast().getInt(); + return mlir::cast(a.value().getValue()[i]).getInt(); } void ArrayAttrIntVals(ArrayAttr a, mlir::SmallVectorImpl &i) { for (size_t k = 0; k < a.size(); ++k) - i.emplace_back((a.getValue()[k]).cast().getInt()); + i.emplace_back(mlir::cast(a.getValue()[k]).getInt()); } ElementsAttr getElementAttributeFromONNXValue(Value value) { ONNXConstantOp constantOp = getONNXConstantOp(value); if (constantOp) - return constantOp.getValueAttr().dyn_cast(); + return mlir::dyn_cast(constantOp.getValueAttr()); return nullptr; } @@ -339,12 +339,12 @@ ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter, // Read first permute vectors. SmallVector initialPerm; for (auto firstPermVal : firstPermAttr.getValue()) - initialPerm.emplace_back(firstPermVal.cast().getInt()); + initialPerm.emplace_back(mlir::cast(firstPermVal).getInt()); // Read second permute vector. Use it as an index in the first permute // vector. SmallVector resPerm; for (auto secondPermVal : secondPermAttr.getValue()) { - auto index = secondPermVal.cast().getInt(); + auto index = mlir::cast(secondPermVal).getInt(); resPerm.emplace_back(initialPerm[index]); } // Convert to Array of Attributes. @@ -359,7 +359,7 @@ bool IsIdentityPermuteVector(ArrayAttr permAttr) { return false; int64_t currentIndex = 0; for (auto permVal : permAttr.getValue()) - if (permVal.cast().getInt() != currentIndex++) + if (mlir::cast(permVal).getInt() != currentIndex++) return false; return true; } @@ -369,7 +369,8 @@ bool HasSpecifiedConstantShape(Value value, Value shape) { if (!hasShapeAndRank(value) || !hasShapeAndRank(shape)) return false; - ArrayRef valueShape = value.getType().cast().getShape(); + ArrayRef valueShape = + mlir::cast(value.getType()).getShape(); ElementsAttr shapeAttr = getElementAttributeFromONNXValue(shape); if (shapeAttr == nullptr) return false; @@ -403,12 +404,12 @@ bool isScalarConstantTensor(mlir::Value v) { bool hasShapeAndRank(Value val) { Type valType = val.getType(); ShapedType shapedType; - if (SeqType seqType = valType.dyn_cast()) - shapedType = seqType.getElementType().dyn_cast(); - else if (OptType optType = valType.dyn_cast()) - shapedType = optType.getElementType().dyn_cast(); + if (SeqType seqType = mlir::dyn_cast(valType)) + shapedType = mlir::dyn_cast(seqType.getElementType()); + else if (OptType optType = mlir::dyn_cast(valType)) + shapedType = mlir::dyn_cast(optType.getElementType()); else - shapedType = valType.dyn_cast(); + shapedType = mlir::dyn_cast(valType); return shapedType && shapedType.hasRank(); } @@ -505,7 +506,7 @@ void getDims(Value val, SmallVectorImpl &dims) { // Create a DenseElementsAttr based on the shape of type at the given index. DenseElementsAttr createDenseElementsAttrFromShapeAtIndex( PatternRewriter &rewriter, Value value, IntegerAttr indexAttr) { - auto inType = value.getType().cast(); + auto inType = mlir::cast(value.getType()); ArrayRef shape = inType.getShape(); int64_t index = indexAttr.getValue().getSExtValue(); SmallVector values(1, shape[index]); @@ -516,7 +517,7 @@ DenseElementsAttr createDenseElementsAttrFromShapeAtIndex( // Create a DenseElementsAttr based on the size of type. DenseElementsAttr createDenseElementsAttrFromSize( PatternRewriter &rewriter, Value value) { - auto inType = value.getType().cast(); + auto inType = mlir::cast(value.getType()); // Output Type should be scalar: tensor SmallVector dims; SmallVector values = {inType.getNumElements()}; @@ -555,8 +556,8 @@ RESULT_TYPE getScalarValue(ElementsAttr denseAttr, Type type) { if (elementaryType.isInteger(16) || elementaryType.isInteger(32) || elementaryType.isInteger(64)) { auto valueIt = denseAttr.getValues().begin(); - return (RESULT_TYPE)(*valueIt).cast().getInt(); - } else if (elementaryType.isa()) { + return (RESULT_TYPE)mlir::cast(*valueIt).getInt(); + } else if (mlir::isa(elementaryType)) { auto valueIt = denseAttr.getValues().begin(); return (RESULT_TYPE)(*valueIt).convertToDouble(); } @@ -567,7 +568,7 @@ RESULT_TYPE getScalarValue(ElementsAttr denseAttr, Type type) { template RESULT_TYPE getScalarValue(ONNXConstantOp constantOp) { Type type = constantOp.getType(); - ElementsAttr attr = constantOp.getValueAttr().dyn_cast(); + ElementsAttr attr = mlir::dyn_cast(constantOp.getValueAttr()); if (!attr) constantOp.emitError("ElementsAttr expected"); return getScalarValue(attr, type); @@ -654,9 +655,9 @@ int64_t mlirTypeToOnnxType(Type elemType) { .Case( [&](BFloat16Type) { onnxType = onnx::TensorProto::BFLOAT16; }) .Case([&](ComplexType type) { - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) onnxType = onnx::TensorProto::COMPLEX64; - else if (type.getElementType().isa()) + else if (mlir::isa(type.getElementType())) onnxType = onnx::TensorProto::COMPLEX128; }) .Case( @@ -716,14 +717,14 @@ bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) { if (elementAttr.getNumElements() != 1) return false; Type elementType = elementAttr.getElementType(); - if (elementType.isa()) { + if (mlir::isa(elementType)) { double floatVal = getScalarValue(elementAttr, elementType); if (floatVal == ceil(floatVal)) { // We essentially have an integer value represented as a float. exponentValue = (int64_t)floatVal; return true; } - } else if (elementType.isa()) { + } else if (mlir::isa(elementType)) { exponentValue = getScalarValue(elementAttr, elementType); return true; } @@ -819,17 +820,17 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) { } // Try with op location. Location loc = op->getLoc(); - if (auto nameLoc = loc.dyn_cast()) { + if (auto nameLoc = mlir::dyn_cast(loc)) { return nameLoc.getName().str(); } - if (auto fusedLoc = loc.dyn_cast()) { + if (auto fusedLoc = mlir::dyn_cast(loc)) { // Combine each location name and set it as nodeName. std::string name; for (Location locIt : fusedLoc.getLocations()) { - if (auto nameLocIt = locIt.dyn_cast()) + if (auto nameLocIt = mlir::dyn_cast(locIt)) name += nameLocIt.getName().str() + "-"; else if (useFileLine) { - if (auto fileLineColLoc = locIt.dyn_cast()) { + if (auto fileLineColLoc = mlir::dyn_cast(locIt)) { getNameFromFileLineLoc(fileLineColLoc, name, "-"); } } @@ -841,7 +842,7 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) { return name; } if (useFileLine) { - if (auto fileLineColLoc = loc.dyn_cast()) { + if (auto fileLineColLoc = mlir::dyn_cast(loc)) { std::string name = ""; getNameFromFileLineLoc(fileLineColLoc, name); return name; diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc index bf0d5d46b7..8ad0944af9 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc @@ -30,7 +30,7 @@ inline bool isNoneValue(mlir::Value value) { /// Check the defining operation of a value. template bool definedBy(mlir::Value v) { - return !v.isa() && llvm::isa(v.getDefiningOp()); + return !mlir::isa(v) && llvm::isa(v.getDefiningOp()); } // Support for recognizing patterns. Detects if the operation "op" has an input @@ -61,7 +61,7 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op, mlir::Value operand = op->getOperand(matchThisOperandIndex); // operand.dump(); // Check for a match with definition of operand. - if (!operand.isa() && + if (!mlir::isa(operand) && llvm::isa(operand.getDefiningOp())) { matchOperand = operand; matchOp = operand.getDefiningOp(); diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp index 3e45aff839..4728b8f2a9 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp @@ -47,17 +47,17 @@ LogicalResult ONNXDequantizeLinearOpShapeHelper::computeShape() { ONNXDequantizeLinearOpAdaptor operandAdaptor( operands, op->getAttrDictionary()); RankedTensorType xTy = - operandAdaptor.getX().getType().dyn_cast(); + mlir::dyn_cast(operandAdaptor.getX().getType()); DimsExpr outputDims; createIE->getShapeAsDims(operandAdaptor.getX(), outputDims); // Get d. - int64_t d = - nonScalar1DLen(operandAdaptor.getXScale().getType().cast()); + int64_t d = nonScalar1DLen( + mlir::cast(operandAdaptor.getXScale().getType())); if (d == ShapedType::kDynamic && !isNoneValue(operandAdaptor.getXZeroPoint())) { d = nonScalar1DLen( - operandAdaptor.getXZeroPoint().getType().cast()); + mlir::cast(operandAdaptor.getXZeroPoint().getType())); } if (d != ShapedType::kDynamic) { @@ -97,7 +97,7 @@ LogicalResult ONNXDequantizeLinearOp::verify() { }; Value scale = getXScale(); - auto scaleTy = scale.getType().cast(); + auto scaleTy = mlir::cast(scale.getType()); if (scaleTy.hasRank() && scaleTy.getRank() > 1) return emitOpError("x_scale must be a scalar or 1-D tensor"); int64_t scaleLen = nonScalar1DLen(scaleTy); @@ -105,11 +105,11 @@ LogicalResult ONNXDequantizeLinearOp::verify() { Value zero = getXZeroPoint(); int64_t zeroLen = ShapedType::kDynamic; if (!isNoneValue(zero)) { - if (auto zeroTy = zero.getType().dyn_cast()) { + if (auto zeroTy = mlir::dyn_cast(zero.getType())) { if (zeroTy.getRank() > 1) return emitOpError("x_zero_point must be a scalar or 1-D tensor"); zeroLen = nonScalar1DLen(zeroTy); - if (auto scaleTy = scale.getType().dyn_cast()) { + if (auto scaleTy = mlir::dyn_cast(scale.getType())) { if ((isScalar(scaleTy) && scaleLen != ShapedType::kDynamic) || (zeroLen != ShapedType::kDynamic && isScalar(zeroTy)) || (zeroLen != ShapedType::kDynamic && @@ -142,7 +142,7 @@ LogicalResult ONNXDequantizeLinearOp::verify() { // If x_scale or x_zero_point is a non-scalar 1-D tensor then quantization // is per-axis. int64_t d = scaleLen != ShapedType::kDynamic ? scaleLen : zeroLen; - if (auto xTy = getX().getType().dyn_cast()) { + if (auto xTy = mlir::dyn_cast(getX().getType())) { int64_t r = xTy.getRank(); // axis attribute must be in the range [-r,rShapedType::kDynamic]. int64_t a = getAxis(); @@ -169,9 +169,9 @@ LogicalResult ONNXDequantizeLinearOp::verify() { LogicalResult ONNXDequantizeLinearOp::inferShapes( std::function doShapeInference) { - if (!getX().getType().dyn_cast()) + if (!mlir::dyn_cast(getX().getType())) return success(); - Type elementType = getY().getType().cast().getElementType(); + Type elementType = mlir::cast(getY().getType()).getElementType(); ONNXDequantizeLinearOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp index 321e66500c..f383509561 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp @@ -55,7 +55,7 @@ LogicalResult ONNXDynamicQuantizeLinearOpShapeHelper::computeShape() { LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes( std::function doShapeInference) { - auto inTy = getX().getType().dyn_cast(); + auto inTy = mlir::dyn_cast(getX().getType()); if (!inTy) return success(); diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/QuantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/QuantizeLinear.cpp index 3591830e23..7ed2f1549b 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/QuantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/QuantizeLinear.cpp @@ -46,7 +46,7 @@ LogicalResult ONNXQuantizeLinearOpShapeHelper::computeShape() { LogicalResult ONNXQuantizeLinearOp::inferShapes( std::function doShapeInference) { - auto inTy = getX().getType().dyn_cast(); + auto inTy = mlir::dyn_cast(getX().getType()); if (!inTy) { return success(); } @@ -58,7 +58,7 @@ LogicalResult ONNXQuantizeLinearOp::inferShapes( elementType = IntegerType::get(getContext(), 8, IntegerType::Unsigned); } else { // If zero point is provided, output type is same as zero point type. - elementType = zero.getType().cast().getElementType(); + elementType = mlir::cast(zero.getType()).getElementType(); } ONNXQuantizeLinearOpShapeHelper shapeHelper(getOperation(), {}); diff --git a/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp b/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp index dbd7440fc9..aa536b3df9 100644 --- a/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp +++ b/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp @@ -168,7 +168,8 @@ LogicalResult ONNXGRUOp::inferShapes( !hasShapeAndRank(getR())) { return success(); } - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXGRUOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -183,7 +184,8 @@ LogicalResult ONNXLSTMOp::inferShapes( !hasShapeAndRank(getR())) { return success(); } - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXLSTMOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } @@ -198,7 +200,8 @@ LogicalResult ONNXRNNOp::inferShapes( !hasShapeAndRank(getR())) { return success(); } - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXRNNOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp b/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp index 7cd2dfe8f0..3e46327a7f 100644 --- a/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp @@ -76,9 +76,9 @@ LogicalResult ONNXSequenceAtOp::inferShapes( std::function doShapeInference) { auto outputType = getResult().getType(); auto inputElementType = - getInputSequence().getType().cast().getElementType(); - if (!inputElementType.isa() && - outputType.isa()) { + mlir::cast(getInputSequence().getType()).getElementType(); + if (!mlir::isa(inputElementType) && + mlir::isa(outputType)) { getResult().setType(inputElementType); } return success(); @@ -91,9 +91,10 @@ LogicalResult ONNXSequenceAtOp::inferShapes( LogicalResult ONNXSequenceConstructOp::inferShapes( std::function doShapeInference) { auto types = getInputs().getTypes(); - ShapedType seqTensorType = types[0].cast(); + ShapedType seqTensorType = mlir::cast(types[0]); for (size_t i = 1; i < types.size(); ++i) { - seqTensorType = sequenceAddType(seqTensorType, types[i].cast()); + seqTensorType = + sequenceAddType(seqTensorType, mlir::cast(types[i])); } getResult().setType(SeqType::get(seqTensorType, types.size())); return success(); @@ -115,8 +116,8 @@ LogicalResult ONNXSequenceEmptyOp::verify() { } // Get element type for seq from the output - auto outputSeqElementType = - getResult().getType().cast().getElementType().cast(); + auto outputSeqElementType = mlir::cast( + mlir::cast(getResult().getType()).getElementType()); if (outputSeqElementType.getElementType() != elementType) return emitError("SequenceEmpty getDtype() does not match the output type"); return success(); @@ -124,7 +125,7 @@ LogicalResult ONNXSequenceEmptyOp::verify() { LogicalResult ONNXSequenceEmptyOp::inferShapes( std::function doShapeInference) { - auto originTy = getResult().getType().cast(); + auto originTy = mlir::cast(getResult().getType()); auto elementTy = originTy.getElementType(); auto returnTy = SeqType::get(elementTy, 0); getResult().setType(returnTy); @@ -137,7 +138,7 @@ LogicalResult ONNXSequenceEmptyOp::inferShapes( LogicalResult ONNXSequenceEraseOp::inferShapes( std::function doShapeInference) { - auto inputTy = getInputSequence().getType().cast(); + auto inputTy = mlir::cast(getInputSequence().getType()); int64_t length = inputTy.getLength(); if (length == 0) @@ -156,13 +157,13 @@ LogicalResult ONNXSequenceInsertOp::verify() { ONNXSequenceInsertOpAdaptor(*this); // These cast should be guaranteed by default verifier - Type seqElementType = operandAdaptor.getInputSequence() - .getType() - .dyn_cast() - .getElementType(); - Type elementType1 = seqElementType.dyn_cast().getElementType(); + Type seqElementType = + mlir::dyn_cast(operandAdaptor.getInputSequence().getType()) + .getElementType(); + Type elementType1 = + mlir::dyn_cast(seqElementType).getElementType(); ShapedType insertType = - operandAdaptor.getTensor().getType().dyn_cast(); + mlir::dyn_cast(operandAdaptor.getTensor().getType()); Type elementType2 = insertType.getElementType(); if (elementType1 != elementType2) { @@ -175,8 +176,8 @@ LogicalResult ONNXSequenceInsertOp::verify() { LogicalResult ONNXSequenceInsertOp::inferShapes( std::function doShapeInference) { // Merge the tensor type for the seq and the inserted tensor - SeqType seqType = getInputSequence().getType().cast(); - ShapedType tensorType = getTensor().getType().cast(); + SeqType seqType = mlir::cast(getInputSequence().getType()); + ShapedType tensorType = mlir::cast(getTensor().getType()); int64_t length = seqType.getLength(); if (length == 0) { // When the input seq is empty, inherit the tensor type @@ -184,7 +185,7 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( } else { int64_t newLength = length == ShapedType::kDynamic ? ShapedType::kDynamic : length + 1; - ShapedType seqTensorType = seqType.getElementType().cast(); + ShapedType seqTensorType = mlir::cast(seqType.getElementType()); seqTensorType = sequenceAddType(seqTensorType, tensorType); getResult().setType(SeqType::get(seqTensorType, newLength)); } @@ -198,8 +199,8 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( LogicalResult ONNXSequenceLengthOp::inferShapes( std::function doShapeInference) { Type outputTy = getResult().getType(); - if (!outputTy.isa() || - outputTy.cast().getRank() != 0) { + if (!mlir::isa(outputTy) || + mlir::cast(outputTy).getRank() != 0) { SmallVector dims; auto builder = Builder(getContext()); Type scalarTy = RankedTensorType::get(dims, builder.getIntegerType(64)); diff --git a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp index cd84fea009..3a17990e56 100644 --- a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp @@ -28,7 +28,7 @@ LogicalResult ONNXSplitToSequenceOp::verify() { if (!hasShapeAndRank(inputValue)) return success(); // Won't be able to do any checking at this stage. - auto inputType = inputValue.getType().cast(); + auto inputType = mlir::cast(inputValue.getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); @@ -52,7 +52,7 @@ LogicalResult ONNXSplitToSequenceOp::verify() { onnx_mlir::Diagnostic::Range(0, 1)); return success(); } - auto splitType = splitValue.getType().cast(); + auto splitType = mlir::cast(splitValue.getType()); ArrayRef splitShape = splitType.getShape(); int64_t splitRank = splitShape.size(); if (splitRank > 1) @@ -92,7 +92,7 @@ LogicalResult ONNXSplitToSequenceOp::inferShapes( // NOTE: all the asserts below are conditions checked in verify() - auto inputType = inputValue.getType().cast(); + auto inputType = mlir::cast(inputValue.getType()); ArrayRef shape = inputType.getShape(); int64_t rank = shape.size(); int64_t axisIndex = getAxis(); @@ -120,7 +120,7 @@ LogicalResult ONNXSplitToSequenceOp::inferShapes( dims.erase(dims.begin() + axisIndex); } } else { - auto splitType = splitValue.getType().cast(); + auto splitType = mlir::cast(splitValue.getType()); ArrayRef splitShape = splitType.getShape(); int64_t splitRank = splitShape.size(); assert(splitRank <= 1 && "invalid split tensor rank"); diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp index 3e2d0c3235..33d1d45b1f 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -46,7 +46,7 @@ bool isAxisInRange(int64_t &axis, int64_t rank, bool includeRank) { } bool isAxisInRange(int64_t &axis, Value val, bool includeRank) { - ShapedType shapedType = val.getType().cast(); + ShapedType shapedType = mlir::cast(val.getType()); assert(shapedType && "expected a shaped type to determine the rank for axis"); return isAxisInRange(axis, shapedType.getRank(), includeRank); } @@ -204,7 +204,7 @@ LogicalResult ONNXOpShapeHelper::setOutputDimsFromLiterals( LogicalResult ONNXOpShapeHelper::setOutputDimsFromTypeWithConstantShape( Type type, int n, bool refineShape) { - RankedTensorType rankedType = type.dyn_cast(); + RankedTensorType rankedType = mlir::dyn_cast(type); if (!rankedType) return failure(); DimsExpr outputDims; @@ -221,12 +221,13 @@ LogicalResult ONNXOpShapeHelper::computeShapeAndUpdateType( // Invoke virtual compute shape. if (failed(computeShape())) return op->emitError("Failed to scan parameters successfully"); - assert((elementType.isa() || !elementType.isa()) && + assert((mlir::isa(elementType) || + !mlir::isa(elementType)) && "element type cannot be a shaped type other than vector type"); uint64_t resNum = op->getNumResults(); for (uint64_t i = 0; i < resNum; ++i) { // If we have an optional type, leave it as is. - if (op->getResults()[i].getType().isa()) + if (mlir::isa(op->getResults()[i].getType())) continue; llvm::SmallVector shapeVect; IndexExpr::getShape(getOutputDims(i), shapeVect); @@ -253,7 +254,7 @@ LogicalResult ONNXOpShapeHelper::computeShapeAndUpdateTypes( " parameters successfully"); for (uint64_t i = 0; i < resNum; ++i) { // If we have an optional type, leave it as is. - if (op->getResults()[i].getType().isa()) + if (mlir::isa(op->getResults()[i].getType())) continue; llvm::SmallVector shapeVect; IndexExpr::getShape(getOutputDims(i), shapeVect); @@ -442,7 +443,7 @@ bool ONNXBroadcastOpShapeHelper::hasNoBroadcast(DimAnalysis *dimAnalysis) { bool ONNXBroadcastOpShapeHelper::hasRankBroadcast() { ValueRange operands = this->operands; for (Value operand : operands) { - auto operandType = operand.getType().cast(); + auto operandType = mlir::cast(operand.getType()); if (outputRank != (uint64_t)operandType.getRank()) return true; } @@ -620,7 +621,7 @@ LogicalResult ONNXBroadcastOpShapeHelper::getAccessExprs(Value operand, // Get info. int64_t loopDepth = loopAccessExprs.size(); int64_t inputSize = inputsDims.size(); - int64_t operandRank = operand.getType().cast().getRank(); + int64_t operandRank = mlir::cast(operand.getType()).getRank(); // Flattened? no more than one loop per dim in output (aka output rank). // Not flattened? one loop per dim in output (aka output rank). if (flattenedInnerDims) @@ -826,7 +827,7 @@ void updateType(Operation *op, Value val, ArrayRef shape, elementType = getElementType(val.getType()); // Get encoding. - if (auto valType = val.getType().dyn_cast()) + if (auto valType = mlir::dyn_cast(val.getType())) if (!encoding) encoding = valType.getEncoding(); @@ -842,7 +843,7 @@ void updateType(Operation *op, Value val, ArrayRef shape, static void resetTypeShapeToQuestionmarks(Value val) { // Only deal with ranked tensor types here. - RankedTensorType valType = val.getType().dyn_cast(); + RankedTensorType valType = mlir::dyn_cast(val.getType()); if (!valType) return; // Reset any compile time literal to unknown (aka question marks). @@ -891,7 +892,8 @@ ONNXCustomOpShapeHelper::ONNXCustomOpShapeHelper(Operation *op, std::vector operandsVector; for (auto indexAttr : inputIndexAttrs.value()) { - operandsVector.push_back(inputs[indexAttr.cast().getInt()]); + operandsVector.push_back( + inputs[mlir::cast(indexAttr).getInt()]); } setOperands(ValueRange(operandsVector)); } diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp index a58f9572f7..c99f6fce2d 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp @@ -30,7 +30,6 @@ #include "src/Dialect/Mlir/IndexExprBuilder.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" -#define GET_OP_FWD_DEFINES 1 #include "src/Dialect/ONNX/ONNXOps.hpp.inc" // ONNXOpShapeHelper is defined in the interface file below. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp index 6d6615f330..8b6762fd0c 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp @@ -30,7 +30,7 @@ LogicalResult ONNXArgMinMaxOpShapeHelper::computeShape() { OP_TYPE argOp = llvm::cast(op); typename OP_TYPE::Adaptor operandAdaptor(operands); Value data = operandAdaptor.getData(); - int64_t dataRank = data.getType().cast().getRank(); + int64_t dataRank = mlir::cast(data.getType()).getRank(); int64_t axisValue = argOp.getAxis(); // axis attribute must be in the range [-r,r-1], where r = rank(data). @@ -77,7 +77,7 @@ LogicalResult ONNXArgMaxOp::verify() { if (!hasShapeAndRank(getOperation())) return success(); - int64_t rank = getData().getType().cast().getRank(); + int64_t rank = mlir::cast(getData().getType()).getRank(); int64_t axisIndex = getAxis(); // axis value must be in the range [-rank, rank-1]. @@ -109,7 +109,7 @@ LogicalResult ONNXArgMinOp::verify() { if (!hasShapeAndRank(getOperation())) return success(); - int64_t rank = getData().getType().cast().getRank(); + int64_t rank = mlir::cast(getData().getType()).getRank(); int64_t axisIndex = getAxis(); // axis value must be in the range [-rank, rank-1]. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp index e4132f90a1..ec6a7a1cb4 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp @@ -83,7 +83,7 @@ LogicalResult ONNXCompressOp::verify() { if (!hasShapeAndRank(getOperation())) return success(); - int64_t inputRank = getInput().getType().cast().getRank(); + int64_t inputRank = mlir::cast(getInput().getType()).getRank(); std::optional optionalAxis = getAxis(); if (optionalAxis.has_value()) { @@ -95,7 +95,7 @@ LogicalResult ONNXCompressOp::verify() { onnx_mlir::Diagnostic::Range(-inputRank, inputRank - 1)); } - int64_t condRank = getCondition().getType().cast().getRank(); + int64_t condRank = mlir::cast(getCondition().getType()).getRank(); if (condRank != 1) return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError( *this->getOperation(), "condition", condRank, @@ -114,7 +114,8 @@ LogicalResult ONNXCompressOp::inferShapes( if (!hasShapeAndRank(getOperation())) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXCompressOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Concat.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Concat.cpp index 5505b505d3..8ec945d233 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Concat.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Concat.cpp @@ -32,7 +32,7 @@ LogicalResult ONNXConcatOpShapeHelper::computeShape() { unsigned numInputs = op->getNumOperands(); Value firstInput = operandAdaptor.getInputs().front(); ArrayRef commonShape = - firstInput.getType().cast().getShape(); + mlir::cast(firstInput.getType()).getShape(); int64_t commonRank = commonShape.size(); int64_t axisIndex = concatOp.getAxis(); @@ -90,7 +90,7 @@ LogicalResult ONNXConcatOp::verify() { return success(); auto commonType = - operandAdaptor.getOperands().front().getType().cast(); + mlir::cast(operandAdaptor.getOperands().front().getType()); ArrayRef commonShape = commonType.getShape(); int64_t commonRank = commonShape.size(); int64_t axisIndex = getAxis(); @@ -108,7 +108,7 @@ LogicalResult ONNXConcatOp::verify() { // of the axis to concatenate on. for (Value operand : operandAdaptor.getOperands()) { ArrayRef operandShape = - operand.getType().cast().getShape(); + mlir::cast(operand.getType()).getShape(); int64_t operandRank = operandShape.size(); if (operandRank != commonRank) return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError( @@ -141,7 +141,7 @@ LogicalResult ONNXConcatOp::inferShapes( if (!hasShapeAndRank(getOperation())) return success(); // Checking value of axis parameter. - auto commonType = getOperand(0).getType().cast(); + auto commonType = mlir::cast(getOperand(0).getType()); auto commonShape = commonType.getShape(); int64_t commonRank = commonShape.size(); int64_t axisIndex = getAxis(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConcatFromSequence.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConcatFromSequence.cpp index 17860c1ba2..d8204170a8 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ConcatFromSequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConcatFromSequence.cpp @@ -28,10 +28,10 @@ LogicalResult ONNXConcatFromSequenceOp::verify() { return success(); // Won't be able to do any checking at this stage. Value inputSequence = operandAdaptor.getInputSequence(); - assert(inputSequence.getType().isa() && + assert(mlir::isa(inputSequence.getType()) && "Incorrect type for a sequence"); - auto seqType = inputSequence.getType().cast(); - auto elemType = seqType.getElementType().cast(); + auto seqType = mlir::cast(inputSequence.getType()); + auto elemType = mlir::cast(seqType.getElementType()); int64_t rank = elemType.getShape().size(); int64_t axisIndex = getAxis(); int64_t newAxisIndex = getNewAxis(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp index 53f5e429dd..bb16aef01c 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp @@ -30,9 +30,10 @@ LogicalResult ONNXConstantOpShapeHelper::computeShape() { ElementsAttr valAttr; if (operandAdaptor.getSparseValue().has_value()) - valAttr = operandAdaptor.getSparseValueAttr().cast(); + valAttr = + mlir::cast(operandAdaptor.getSparseValueAttr()); else - valAttr = operandAdaptor.getValueAttr().cast(); + valAttr = mlir::cast(operandAdaptor.getValueAttr()); return setOutputDimsFromTypeWithConstantShape(valAttr.getType()); } @@ -90,11 +91,11 @@ LogicalResult ONNXConstantOp::inferShapes( } ElementsAttr valAttr; if (getSparseValue().has_value()) - valAttr = getSparseValueAttr().cast(); + valAttr = mlir::cast(getSparseValueAttr()); else - valAttr = getValueAttr().cast(); + valAttr = mlir::cast(getValueAttr()); Type elementType = - valAttr.getType().cast().getElementType(); + mlir::cast(valAttr.getType()).getElementType(); ONNXConstantOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp index bac82b2255..c34683341f 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp @@ -31,7 +31,7 @@ LogicalResult ONNXConstantOfShapeOpShapeHelper::computeShape() { Value input = operandAdaptor.getInput(); DimsExpr outputDims; - auto inputShape = input.getType().cast().getShape(); + auto inputShape = mlir::cast(input.getType()).getShape(); if (inputShape[0] == 0) { // If 'input' is an empty tensor, the output would be a scalar. // Represent this by an empty outputDims. @@ -56,7 +56,7 @@ LogicalResult ONNXConstantOfShapeOp::verify() { if (!hasShapeAndRank(input)) return success(); - auto inputShape = input.getType().cast().getShape(); + auto inputShape = mlir::cast(input.getType()).getShape(); if (inputShape.size() != 1) return emitOpError("Input tensor must be a 1D tensor"); @@ -69,11 +69,11 @@ LogicalResult ONNXConstantOfShapeOp::verify() { // If the values are valid, it is possible to infer shape. if (auto constantOp = getONNXConstantOp(input)) { ElementsAttr valueAttribute = - constantOp.getValueAttr().cast(); + mlir::cast(constantOp.getValueAttr()); // Get repeat values from valueAttribute. auto valueIt = valueAttribute.getValues().begin(); for (int i = 0; i < inputShape[0]; ++i) { - auto dim = (*valueIt++).cast().getInt(); + auto dim = mlir::cast((*valueIt++)).getInt(); if (dim < 0) return emitOpError("All values of the input tensor must be >=0"); } @@ -115,8 +115,9 @@ LogicalResult ONNXConstantOfShapeOp::inferShapes( // 'value' attribute is a one-element tensor whose value and datatype are // used to set the output tensor value and datatype. if (getValue().has_value()) { - elementType = - getValueAttr().cast().getShapedType().getElementType(); + elementType = mlir::cast(getValueAttr()) + .getShapedType() + .getElementType(); } else { // If 'value' attribute is not specified, it defaults to a tensor of // value 0 and datatype float32. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp index 786717478d..68c9a213a7 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp @@ -71,7 +71,7 @@ LogicalResult ONNXDepthToSpaceOp::verify() { // Won't be able to do any checking at this stage. return success(); } - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto inputShape = inputType.getShape(); if (inputShape.size() != 4) return emitOpError("Input should have a rank of four"); @@ -104,7 +104,8 @@ LogicalResult ONNXDepthToSpaceOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXDepthToSpaceOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp index c7b0a5a323..9e25039d8f 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp @@ -34,14 +34,14 @@ LogicalResult ONNXExpandOpShapeHelper::computeShape() { Value shape = operandAdaptor.getShape(); Operation *shapeDefOp = shape.getDefiningOp(); - ShapedType shapeType = shape.getType().dyn_cast_or_null(); + ShapedType shapeType = mlir::dyn_cast_or_null(shape.getType()); if (!shapeType) return op->emitError("expected shape parameter to be defined"); if (ShapedType::isDynamic(shapeType.getShape()[0])) return op->emitError("expected size of shape parameter to be defined"); if (ONNXShapeOp shapeOp = dyn_cast_or_null(shapeDefOp)) { - assert(shapeOp.getData().getType().isa() && "expected"); + assert(mlir::isa(shapeOp.getData().getType()) && "expected"); // Consider a first case where the expand.shape is produced by a shape op. // Infer its shape and use it as the requested shape. // Compute the output of the shape operation. We have to use its shape @@ -66,7 +66,7 @@ LogicalResult ONNXExpandOpShapeHelper::computeShape() { return success(); } - if (!shape.getType().isa()) + if (!mlir::isa(shape.getType())) return op->emitError("Expecting a shaped type"); SmallVector constVals; createIE->getIntFromArrayAsSymbols(shape, constVals); @@ -87,7 +87,7 @@ LogicalResult ONNXExpandOp::verify() { // Get operands. auto shape = operandAdaptor.getShape(); // Check input. - auto shapeType = shape.getType().dyn_cast_or_null(); + auto shapeType = mlir::dyn_cast_or_null(shape.getType()); if (shapeType && shapeType.hasRank()) { if (shapeType.getRank() != 1) return emitOpError("Shape has a rank of 1"); @@ -104,11 +104,13 @@ LogicalResult ONNXExpandOp::inferShapes( if (!hasShapeAndRank(getInput()) || !hasShapeAndRank(getShape())) return success(); - ShapedType shapeType = getShape().getType().dyn_cast_or_null(); + ShapedType shapeType = + mlir::dyn_cast_or_null(getShape().getType()); if (!shapeType || ShapedType::isDynamic(shapeType.getShape()[0])) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXExpandOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp index c6dad3c26c..18ba80810a 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp @@ -48,7 +48,8 @@ LogicalResult ONNXEyeLikeOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - RankedTensorType inputType = getInput().getType().cast(); + RankedTensorType inputType = + mlir::cast(getInput().getType()); Type elementType; if (getDtypeAttr()) { auto builder = OpBuilder(getContext()); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp index 487a681fa5..5ff505b754 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp @@ -30,7 +30,7 @@ LogicalResult ONNXFlattenOpShapeHelper::computeShape() { ONNXFlattenOpAdaptor operandAdaptor(operands); ONNXFlattenOp flattenOp = llvm::cast(op); Value input = operandAdaptor.getInput(); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputType.getRank(); int64_t axis = flattenOp.getAxis(); @@ -74,7 +74,7 @@ LogicalResult ONNXFlattenOp::verify() { if (!hasShapeAndRank(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputShape.size(); int64_t axisValue = getAxis(); @@ -98,7 +98,8 @@ LogicalResult ONNXFlattenOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXFlattenOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Gather.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Gather.cpp index e957663847..ad032db65c 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Gather.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Gather.cpp @@ -67,7 +67,8 @@ LogicalResult ONNXGatherOp::verify() { if (!hasShapeAndRank(getOperation())) return success(); - auto dataType = operandAdaptor.getData().getType().cast(); + auto dataType = + mlir::cast(operandAdaptor.getData().getType()); ArrayRef dataShape = dataType.getShape(); int64_t dataRank = dataShape.size(); int64_t axisValue = getAxis(); @@ -90,7 +91,8 @@ LogicalResult ONNXGatherOp::inferShapes( if (!hasShapeAndRank(getOperation())) return success(); - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXGatherOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp index 22d43625a4..ce35ad81b3 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp @@ -44,8 +44,8 @@ LogicalResult ONNXGatherElementsOp::verify() { // Get operands and attributes. Value data = operandAdaptor.getData(); Value indices = operandAdaptor.getIndices(); - auto dataType = data.getType().cast(); - auto indicesType = indices.getType().cast(); + auto dataType = mlir::cast(data.getType()); + auto indicesType = mlir::cast(indices.getType()); int64_t dataRank = dataType.getRank(); int64_t indicesRank = indicesType.getRank(); int64_t axis = this->getAxis(); @@ -97,7 +97,8 @@ LogicalResult ONNXGatherElementsOp::inferShapes( if (!hasShapeAndRank(getOperation())) return success(); - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXGatherElementsOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp index eeec664491..7bf23643cd 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp @@ -38,8 +38,8 @@ LogicalResult ONNXGatherNDOpShapeHelper::computeShape() { // int64_t b = op->getBatchDims(); int64_t b = operandAdaptor.getBatchDims(); - assert(indices.getType().isa() && "Expecting a shaped type"); - auto indicesType = indices.getType().cast(); + assert(mlir::isa(indices.getType()) && "Expecting a shaped type"); + auto indicesType = mlir::cast(indices.getType()); ArrayRef indicesShape = indicesType.getShape(); int64_t indicesLastDim = indicesShape[indicesRank - 1]; int64_t outputRank = dataRank + indicesRank - indicesLastDim - 1 - b; @@ -94,8 +94,8 @@ LogicalResult ONNXGatherNDOp::verify() { // Get operands and attributes. Value data = operandAdaptor.getData(); Value indices = operandAdaptor.getIndices(); - auto dataType = data.getType().cast(); - auto indicesType = indices.getType().cast(); + auto dataType = mlir::cast(data.getType()); + auto indicesType = mlir::cast(indices.getType()); int64_t dataRank = dataType.getRank(); int64_t indicesRank = indicesType.getRank(); int64_t b = getBatchDims(); @@ -179,12 +179,13 @@ LogicalResult ONNXGatherNDOp::inferShapes( // Therefore 'indices.shape[-1]' must be known in order to compute the output // shape. ArrayRef indicesShape = - getIndices().getType().cast().getShape(); + mlir::cast(getIndices().getType()).getShape(); int64_t indicesRank = indicesShape.size(); if (indicesShape[indicesRank - 1] == ShapedType::kDynamic) return success(); // cannot infer the output shape yet. - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXGatherNDOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp index a56e4fc3a6..3b1699b35e 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp @@ -76,7 +76,7 @@ LogicalResult ONNXOneHotOp::verify() { Value indices = operandAdaptor.getIndices(); if (hasShapeAndRank(indices)) { // Get rank. - int64_t indicesRank = indices.getType().cast().getRank(); + int64_t indicesRank = mlir::cast(indices.getType()).getRank(); // Verify axis. int64_t axisValue = getAxis(); // Unusually, with a rank of 3, acceptable values are 0 (before first) to 3 @@ -89,7 +89,7 @@ LogicalResult ONNXOneHotOp::verify() { // Check that values is a rank 2 with 2 elements Value values = operandAdaptor.getValues(); if (hasShapeAndRank(values)) { - ShapedType valuesShape = values.getType().cast(); + ShapedType valuesShape = mlir::cast(values.getType()); if (valuesShape.getRank() != 1) return emitOpError("OneHot values must be 1D tensor"); int64_t dim = valuesShape.getDimSize(0); @@ -99,7 +99,7 @@ LogicalResult ONNXOneHotOp::verify() { // Depth is a scalar, check when its a tensor of rank 0 or 1. Value depth = operandAdaptor.getDepth(); if (hasShapeAndRank(depth)) { - ShapedType depthShape = depth.getType().cast(); + ShapedType depthShape = mlir::cast(depth.getType()); if (depthShape.getRank() == 1) { int64_t dim = depthShape.getDimSize(0); if (dim >= 0 && dim != 1) @@ -122,7 +122,8 @@ LogicalResult ONNXOneHotOp::inferShapes( if (!hasShapeAndRank(getIndices())) return success(); - Type elementType = getValues().getType().cast().getElementType(); + Type elementType = + mlir::cast(getValues().getType()).getElementType(); ONNXOneHotOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Optional.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Optional.cpp index 4da033c445..1c2f1f0c0c 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Optional.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Optional.cpp @@ -46,14 +46,14 @@ LogicalResult ONNXOptionalOp::inferShapes( //===----------------------------------------------------------------------===// LogicalResult ONNXOptionalGetElementOp::verify() { - if (!getInput().getType().isa()) + if (!mlir::isa(getInput().getType())) return emitError("OptionalGetElement input should have optional type"); return success(); } LogicalResult ONNXOptionalGetElementOp::inferShapes( std::function doShapeInference) { - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = mlir::cast(getInput().getType()).getElementType(); getResult().setType(elementType); return success(); } @@ -63,7 +63,7 @@ LogicalResult ONNXOptionalGetElementOp::inferShapes( //===----------------------------------------------------------------------===// LogicalResult ONNXOptionalHasElementOp::verify() { - if (!getInput().getType().isa()) + if (!mlir::isa(getInput().getType())) return emitError("OptionalHasElement input should have optional type"); return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp index b00edc4ad5..16e4713a91 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp @@ -74,12 +74,12 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXPadOp::verify() { - ShapedType dataTy = getData().getType().cast(); + ShapedType dataTy = mlir::cast(getData().getType()); Type constTy = getConstantValue().getType(); if (!isNoneValue(getConstantValue())) { // Check that the constant has the same element type as the input - ShapedType shapedConstTy = constTy.cast(); + ShapedType shapedConstTy = mlir::cast(constTy); if (dataTy.getElementType() != shapedConstTy.getElementType()) { return emitOpError("Pad with constant_value that doesn't match the " "element type of the input."); @@ -103,7 +103,8 @@ LogicalResult ONNXPadOp::inferShapes( if (!hasShapeAndRank(getData()) || !hasShapeAndRank(getPads())) return success(); - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXPadOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Range.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Range.cpp index d10a79d59a..a52b6f1534 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Range.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Range.cpp @@ -70,9 +70,9 @@ LogicalResult ONNXRangeOp::verify() { !hasShapeAndRank(getDelta())) return success(); - auto startTensorTy = getStart().getType().cast(); - auto limitTensorTy = getLimit().getType().cast(); - auto deltaTensorTy = getDelta().getType().cast(); + auto startTensorTy = mlir::cast(getStart().getType()); + auto limitTensorTy = mlir::cast(getLimit().getType()); + auto deltaTensorTy = mlir::cast(getDelta().getType()); // Only rank 0 or 1 input tensors are supported. if (startTensorTy.getShape().size() > 1) @@ -125,7 +125,7 @@ LogicalResult ONNXRangeOp::inferShapes( return success(); Type elementType = - getStart().getType().cast().getElementType(); + mlir::cast(getStart().getType()).getElementType(); ONNXRangeOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp index 797f3fcde9..8a9579534e 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp @@ -31,7 +31,7 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() { // Get info about input data operand. Value data = operandAdaptor.getData(); - int64_t dataRank = data.getType().cast().getShape().size(); + int64_t dataRank = mlir::cast(data.getType()).getShape().size(); // Get info about shape operand. Value shape = operandAdaptor.getShape(); @@ -139,7 +139,8 @@ LogicalResult ONNXReshapeOp::inferShapes( if (!hasShapeAndRank(getData()) && !hasStaticShape(getShape().getType())) return success(); - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXReshapeOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp index b804297414..e4e89239c4 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp @@ -124,7 +124,8 @@ LogicalResult ONNXResizeOp::inferShapes( if (!hasShapeAndRank(getX())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXResizeOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ReverseSequence.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ReverseSequence.cpp index 9f0f8277ce..39d902e2d4 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ReverseSequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ReverseSequence.cpp @@ -42,10 +42,10 @@ LogicalResult ONNXReverseSequenceOp::verify() { ONNXReverseSequenceOpAdaptor operandAdaptor = ONNXReverseSequenceOpAdaptor(*this); - auto sequence_lensTy = - operandAdaptor.getSequenceLens().getType().dyn_cast(); + auto sequence_lensTy = mlir::dyn_cast( + operandAdaptor.getSequenceLens().getType()); auto inputTy = - operandAdaptor.getInput().getType().dyn_cast(); + mlir::dyn_cast(operandAdaptor.getInput().getType()); // sequence_lens should be 1D tensor if (sequence_lensTy) { @@ -81,7 +81,8 @@ LogicalResult ONNXReverseSequenceOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXReverseSequenceOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp index 61ca0234e6..279cbecbeb 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp @@ -82,7 +82,7 @@ void ONNXShapeOpShapeHelper::computeSelectedDataShape( // Get rank of data operand. ONNXShapeOpAdaptor operandAdaptor(shapeOp); Value data = operandAdaptor.getData(); - ShapedType shapedType = data.getType().dyn_cast_or_null(); + ShapedType shapedType = mlir::dyn_cast_or_null(data.getType()); assert(shapedType && shapedType.hasRank() && "need shaped type with rank"); int64_t rank = shapedType.getRank(); // Compute the normalized start/end. Negative value means counting diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp index 5d4715f224..88b2cdb8a7 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp @@ -29,7 +29,7 @@ LogicalResult ONNXSliceOpShapeHelper::computeShape() { // Get info about input data operand. ONNXSliceOpAdaptor operandAdaptor(operands); Value data = operandAdaptor.getData(); - uint64_t dataRank = data.getType().cast().getShape().size(); + uint64_t dataRank = mlir::cast(data.getType()).getShape().size(); // Get each of the axes, and save the literal values in axesIntLit. SmallVector axesIntLit; @@ -195,7 +195,8 @@ LogicalResult ONNXSliceOp::inferShapes( } } - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXSliceOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp index 366096dcfd..ce3f4f84d0 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp @@ -73,7 +73,7 @@ LogicalResult ONNXSpaceToDepthOp::verify() { // Won't be able to do any checking at this stage. return success(); } - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto inputShape = inputType.getShape(); if (inputShape.size() != 4) return emitOpError("Input should have a rank of four"); @@ -106,7 +106,8 @@ LogicalResult ONNXSpaceToDepthOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXSpaceToDepthOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp index 708503fa90..be2eeaa887 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp @@ -154,7 +154,7 @@ LogicalResult ONNXSplitOp::verify() { if (!hasShapeAndRank(input)) return success(); // Won't be able to do any checking at this stage. - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); int64_t inputRank = inputType.getShape().size(); int64_t axisIndex = getAxis(); @@ -177,7 +177,7 @@ LogicalResult ONNXSplitOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); Type elementType = inputType.getElementType(); ONNXSplitOpShapeHelper shapeHelper(getOperation(), {}); // Same time for all results. @@ -190,7 +190,7 @@ LogicalResult ONNXSplitV13Op::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); Type elementType = inputType.getElementType(); ONNXSplitV13OpShapeHelper shapeHelper(getOperation(), {}); // Same time for all results. @@ -203,7 +203,7 @@ LogicalResult ONNXSplitV11Op::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - auto inputType = getInput().getType().cast(); + auto inputType = mlir::cast(getInput().getType()); Type elementType = inputType.getElementType(); ONNXSplitV11OpShapeHelper shapeHelper(getOperation(), {}); // Same time for all results. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp index e7f1641ebd..786f1e136a 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp @@ -156,7 +156,7 @@ LogicalResult ONNXSqueezeV11OpShapeHelper::computeShape() { LogicalResult ONNXSqueezeOp::inferShapes( std::function doShapeInference) { - auto dataType = getData().getType().dyn_cast(); + auto dataType = mlir::dyn_cast(getData().getType()); if (!dataType) return success(); @@ -167,7 +167,7 @@ LogicalResult ONNXSqueezeOp::inferShapes( LogicalResult ONNXSqueezeV11Op::inferShapes( std::function doShapeInference) { - auto dataType = getData().getType().dyn_cast(); + auto dataType = mlir::dyn_cast(getData().getType()); if (!dataType) return success(); @@ -203,7 +203,7 @@ OpFoldResult ONNXSqueezeOp::fold(FoldAdaptor adaptor) { "Shape should be static when the inputs are constant"); OnnxElementsAttrBuilder elementsBuilder(getContext()); - return elementsBuilder.reshape(adaptor.getData().cast(), + return elementsBuilder.reshape(mlir::cast(adaptor.getData()), getShape(getSqueezed().getType())); } @@ -222,6 +222,6 @@ OpFoldResult ONNXSqueezeV11Op::fold(FoldAdaptor adaptor) { "Shape should be static when the inputs are constant"); OnnxElementsAttrBuilder elementsBuilder(getContext()); - return elementsBuilder.reshape(adaptor.getData().cast(), + return elementsBuilder.reshape(mlir::cast(adaptor.getData()), getShape(getSqueezed().getType())); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp index eff2a08a5e..96f403c409 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp @@ -61,11 +61,12 @@ LogicalResult ONNXTileOp::inferShapes( return success(); // 'repeats' tensor is an 1D tensor. - auto repeatsTensorTy = getRepeats().getType().cast(); + auto repeatsTensorTy = mlir::cast(getRepeats().getType()); if (repeatsTensorTy.getShape().size() != 1) return emitError("Repeats tensor must have rank one"); - Type elementType = getInput().getType().cast().getElementType(); + Type elementType = + mlir::cast(getInput().getType()).getElementType(); ONNXTileOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp index c1e97a270a..50e8663983 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp @@ -74,7 +74,8 @@ LogicalResult ONNXTransposeOp::inferShapes( if (!hasShapeAndRank(getData())) return success(); - Type elementType = getData().getType().cast().getElementType(); + Type elementType = + mlir::cast(getData().getType()).getElementType(); ONNXTransposeOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp index 2eaefe4562..f9177073e0 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp @@ -65,7 +65,7 @@ LogicalResult ONNXUniqueOp::verify() { return success(); // Too early to verify. // verify axis - int64_t XRank = X.getType().cast().getRank(); + int64_t XRank = mlir::cast(X.getType()).getRank(); std::optional optionalAxis = getAxis(); if (optionalAxis.has_value()) { @@ -87,7 +87,7 @@ LogicalResult ONNXUniqueOp::verify() { LogicalResult ONNXUniqueOp::inferShapes( std::function doShapeInference) { Builder b = Builder(getContext()); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = mlir::cast(getX().getType()).getElementType(); Type indexType = b.getI64Type(); ONNXUniqueOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateTypes( diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp index 7c52888771..5603ae408b 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp @@ -124,7 +124,7 @@ LogicalResult ONNXUnsqueezeV11OpShapeHelper::computeShape() { LogicalResult ONNXUnsqueezeOp::inferShapes( std::function doShapeInference) { - auto dataType = getData().getType().dyn_cast(); + auto dataType = mlir::dyn_cast(getData().getType()); if (!dataType) return success(); @@ -135,7 +135,7 @@ LogicalResult ONNXUnsqueezeOp::inferShapes( LogicalResult ONNXUnsqueezeV11Op::inferShapes( std::function doShapeInference) { - auto dataType = getData().getType().dyn_cast(); + auto dataType = mlir::dyn_cast(getData().getType()); if (!dataType) return success(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp index 2c35b2a156..bdb06ab04a 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp @@ -44,7 +44,7 @@ LogicalResult ONNXUpsampleOpShapeHelper::computeShape() { auto scalesConstOp = getONNXConstantOp(operandAdaptor.getScales()); if (scalesConstOp) { // Can get the scales as constant. - auto valueAttr = scalesConstOp.getValueAttr().dyn_cast(); + auto valueAttr = mlir::dyn_cast(scalesConstOp.getValueAttr()); if (!valueAttr) return op->emitError("Scales constant is not an ElementsAttr"); for (int64_t i = 0; i < xRank; ++i) { @@ -71,11 +71,11 @@ LogicalResult ONNXUpsampleOp::verify() { if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getScales())) return success(); - auto inputTy = getX().getType().cast(); + auto inputTy = mlir::cast(getX().getType()); int32_t inputRank = inputTy.getShape().size(); // Safety checks on scale argument - auto scalesTy = getScales().getType().cast(); + auto scalesTy = mlir::cast(getScales().getType()); if (scalesTy.getShape().size() != 1) { return emitError("Scales tensor must be rank 1"); } @@ -88,7 +88,7 @@ LogicalResult ONNXUpsampleOp::verify() { if (!scalesConstOp) { return success(); } - auto valueAttr = scalesConstOp.getValueAttr().dyn_cast(); + auto valueAttr = mlir::dyn_cast(scalesConstOp.getValueAttr()); if (!valueAttr) { return emitError("Scales constant is not an ElementsAttr"); } @@ -116,7 +116,8 @@ LogicalResult ONNXUpsampleOp::inferShapes( if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getScales())) return success(); - Type elementType = getX().getType().cast().getElementType(); + Type elementType = + mlir::cast(getX().getType()).getElementType(); ONNXUpsampleOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/Transforms/ConstProp.cpp b/src/Dialect/ONNX/Transforms/ConstProp.cpp index b8b5585665..0eb09103c2 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.cpp +++ b/src/Dialect/ONNX/Transforms/ConstProp.cpp @@ -109,7 +109,7 @@ bool isNotDisabled(StringRef name) { ElementsAttr getConstValueElements(Value constValue) { ONNXConstantOp constOp = cast(constValue.getDefiningOp()); - return constOp.getValueAttr().cast(); + return mlir::cast(constOp.getValueAttr()); } // Creates ONNXConstantOp with the location from replacingValue. @@ -292,7 +292,7 @@ constexpr auto subCombiner(Type elemType) { template Value ConstPropElementwiseBinary(PatternRewriter &rewriter, Value replacingValue, Value lhsValue, Value rhsValue) { - auto replacingType = replacingValue.getType().cast(); + auto replacingType = mlir::cast(replacingValue.getType()); ElementsAttr lhs = getConstValueElements(lhsValue); ElementsAttr rhs = getConstValueElements(rhsValue); @@ -311,7 +311,7 @@ template Value ConstPropVariadicElementwiseBinary( PatternRewriter &rewriter, Value replacingValue, ValueRange inputList) { assert(inputList.size() > 0 && "The variadic input is empty"); - auto replacingType = replacingValue.getType().cast(); + auto replacingType = mlir::cast(replacingValue.getType()); Value lhsValue = inputList[0]; if (inputList.size() == 1) @@ -367,7 +367,7 @@ template Value ConstPropElementwiseUnary( PatternRewriter &rewriter, Value replacingValue, Value constValue) { Type replacingElemType = - replacingValue.getType().cast().getElementType(); + mlir::cast(replacingValue.getType()).getElementType(); ElementsAttr constElements = getConstValueElements(constValue); assert(replacingElemType == constElements.getElementType() && @@ -389,7 +389,7 @@ Value ConstPropElementwiseUnary( Value ConstPropWhere(PatternRewriter &rewriter, Value replacingValue, Value condValue, Value lhsValue, Value rhsValue) { - auto replacingType = replacingValue.getType().cast(); + auto replacingType = mlir::cast(replacingValue.getType()); ElementsAttr cond = getConstValueElements(condValue); assert(cond.getElementType().isInteger(1) && @@ -425,9 +425,10 @@ Attribute getIdentity(Builder &builder, Type type) { if constexpr (std::is_same_v) { return builder.getZeroAttr(type); } else if constexpr (std::is_same_v) { - if (auto itype = type.dyn_cast()) + if (auto itype = mlir::dyn_cast(type)) return builder.getIntegerAttr(type, APInt(itype.getWidth(), 1)); - assert(type.isa() && "only supported types are integer, float"); + assert(mlir::isa(type) && + "only supported types are integer, float"); return builder.getFloatAttr(type, 1.0); } else { // Follow NumPy which doesn't support empty tensor for Min, Max, Mean. @@ -451,7 +452,7 @@ Value ConstPropReduceAxesRange(PatternRewriter &rewriter, Value replacingValue, // Find absoluteAxes, converting any negative axes to non-negative. SmallVector absoluteAxes; ElementsAttr data = getConstValueElements(dataValue); - int64_t rank = data.getType().cast().getRank(); + int64_t rank = mlir::cast(data.getType()).getRank(); for (APInt a : axesRange) { int64_t axis = a.getSExtValue(); assert(-rank <= axis && axis < rank && "axis out of range"); @@ -478,7 +479,7 @@ Value ConstPropReduceAxesRange(PatternRewriter &rewriter, Value replacingValue, } else if (data.empty()) { Attribute identity = getIdentity(rewriter, elemType); reduced = DenseElementsAttr::get( - replacingValue.getType().cast(), {identity}); + mlir::cast(replacingValue.getType()), {identity}); } else { bool keepdims = getSIntAttr(op, "keepdims", /*default=*/1) != 0; OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); @@ -716,10 +717,10 @@ Value ConstPropTranspose( PatternRewriter &rewriter, Value replacingValue, Value constValue) { // TODO: figure out if default may be omitted and what to do in that case ArrayAttr permAttr = - replacingValue.getDefiningOp()->getAttr("perm").cast(); + mlir::cast(replacingValue.getDefiningOp()->getAttr("perm")); SmallVector perm; for (auto permVal : permAttr.getValue()) - perm.emplace_back(permVal.cast().getInt()); + perm.emplace_back(mlir::cast(permVal).getInt()); ElementsAttr constElements = getConstValueElements(constValue); OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); @@ -853,7 +854,7 @@ Value ConstPropPad(PatternRewriter &rewriter, Value replacingValue, Value data, Value ConstPropConcat(PatternRewriter &rewriter, Value replacingValue, ValueRange operands, IntegerAttr axisAttr) { - ShapedType outputType = replacingValue.getType().cast(); + ShapedType outputType = mlir::cast(replacingValue.getType()); int64_t axis = axisAttr.getValue().getSExtValue(); if (axis < 0) axis += outputType.getRank(); @@ -894,7 +895,7 @@ Value ConstPropGather(PatternRewriter &rewriter, Value replacingValue, ONNXGatherOp gatherOp = cast(op); int64_t axis = gatherOp.getAxis(); if (axis < 0) - axis += inputValue.getType().cast().getRank(); + axis += mlir::cast(inputValue.getType()).getRank(); OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); ElementsAttr inputElements = getConstValueElements(inputValue); @@ -927,7 +928,7 @@ Value ConstPropConstantOfShape(PatternRewriter &rewriter, Value replacingValue, // ONNXConstantOfShapeOp::inferShapes() makes sure that the 'value' attribute // here is specified - ElementsAttr constElements = value.cast(); + ElementsAttr constElements = mlir::cast(value); OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); ElementsAttr expandedElements = @@ -942,7 +943,7 @@ Value ConstPropConstantOfShape(PatternRewriter &rewriter, Value replacingValue, Value ConstPropRange(PatternRewriter &rewriter, Value replacingValue, Value start, Value limit, Value delta) { - ShapedType replacingType = replacingValue.getType().cast(); + ShapedType replacingType = mlir::cast(replacingValue.getType()); OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); ElementsAttr rangeElements = elementsBuilder.range( @@ -976,7 +977,7 @@ Value ConstPropNonZero( std::vector ConstPropSplit(PatternRewriter &rewriter, ResultRange replacingValues, Value input, Value split, int64_t axis) { unsigned numResults = replacingValues.size(); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); ArrayRef inputShape = inputType.getShape(); int64_t splitAxisSize = inputShape[axis]; diff --git a/src/Dialect/ONNX/Transforms/ConstProp.td b/src/Dialect/ONNX/Transforms/ConstProp.td index 46a4fc1d3f..bd79d13bf8 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.td +++ b/src/Dialect/ONNX/Transforms/ConstProp.td @@ -45,7 +45,7 @@ def HasOneUse : Constraint, "op has exactly one use" def IsNoneType : Constraint(($_self).getType())">>; -def IsIntOrFloatType : Constraint(($_self).getType().cast().getElementType())">>; +def IsIntOrFloatType : Constraint(mlir::cast(($_self).getType()).getElementType())">>; def IsNotAConstant : Constraint(($_self).getDefiningOp())">, @@ -58,7 +58,7 @@ def IsFromDenseONNXConstantOp: def IsFromDenseONNXConstantOpOrNone: Constraint< - CPred<"isDenseONNXConstant($_self) || ($_self.getType().isa())">, + CPred<"isDenseONNXConstant($_self) || mlir::isa($_self.getType())">, "Value is none or produced by a true dense ONNXConstantOp" >; @@ -112,7 +112,7 @@ class EqualString : Constraint>; // Creation helpers: def CreateZeroTensorOfType: NativeCodeCall< - "ConstZeroTensor($_builder, $_loc, $0.getType().cast())" + "ConstZeroTensor($_builder, $_loc, mlir::cast($0.getType()))" >; def CreateAddOfTwoConst : @@ -594,10 +594,10 @@ def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern", def ModConstPropPattern : NamedPat<"ModConstPropPattern", (ONNXModOp:$modOp (ONNXConstantOp:$A $_, $_, $_, $_, $_, $_, $_, $_), - (ONNXConstantOp:$B $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$B $_, $_, $_, $_, $_, $_, $_, $_), $fmod), (CreateModOfTwoConst $modOp, $A, $B), - [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), + [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), (SatisfiesExpansionBound:$modOp)]>; //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/Transforms/ConvOpt.cpp b/src/Dialect/ONNX/Transforms/ConvOpt.cpp index 7aafaa3fcd..827e245522 100644 --- a/src/Dialect/ONNX/Transforms/ConvOpt.cpp +++ b/src/Dialect/ONNX/Transforms/ConvOpt.cpp @@ -45,8 +45,8 @@ bool ExpressONNXConvOpAsMatmul(ONNXConvOp convOp, bool verbose = 0) { return false; if (hasBias && !hasShapeAndRank(B)) return false; - ShapedType xType = X.getType().cast(); - ShapedType wType = W.getType().cast(); + ShapedType xType = mlir::cast(X.getType()); + ShapedType wType = mlir::cast(W.getType()); auto xShape = xType.getShape(); auto wShape = wType.getShape(); int64_t rank = xShape.size(); @@ -146,8 +146,8 @@ struct Conv1x1ToMatmulPattern : public ConversionPattern { Value W = convOp.getW(); Value B = convOp.getB(); bool hasBias = !onnx_mlir::isNoneValue(B); - ShapedType xType = X.getType().cast(); - ShapedType wType = W.getType().cast(); + ShapedType xType = mlir::cast(X.getType()); + ShapedType wType = mlir::cast(W.getType()); Type elementType = xType.getElementType(); auto xShape = xType.getShape(); auto wShape = wType.getShape(); diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index 37be6230a2..6747507a12 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -48,24 +48,25 @@ DenseElementsAttr createDenseArrayAttr( PatternRewriter &rewriter, ArrayAttr origAttrs) { assert(origAttrs && "handle EXISTING ArrayAttr only"); - if (origAttrs.getValue()[0].dyn_cast()) { + if (mlir::dyn_cast(origAttrs.getValue()[0])) { Type elementType = rewriter.getF32Type(); int nElements = origAttrs.getValue().size(); SmallVector wrapper(nElements, 0); for (int i = 0; i < nElements; ++i) - wrapper[i] = origAttrs.getValue()[i].cast().getValueAsDouble(); + wrapper[i] = + mlir::cast(origAttrs.getValue()[i]).getValueAsDouble(); return DenseElementsAttr::get( RankedTensorType::get(wrapper.size(), elementType), llvm::ArrayRef(wrapper)); } - if (origAttrs.getValue()[0].dyn_cast()) { + if (mlir::dyn_cast(origAttrs.getValue()[0])) { Type elementType = rewriter.getIntegerType(64); int nElements = origAttrs.getValue().size(); SmallVector wrapper(nElements, 0); for (int i = 0; i < nElements; ++i) - wrapper[i] = origAttrs.getValue()[i].cast().getInt(); + wrapper[i] = mlir::cast(origAttrs.getValue()[i]).getInt(); return DenseElementsAttr::get( RankedTensorType::get(wrapper.size(), elementType), @@ -79,18 +80,18 @@ DenseElementsAttr createDenseArrayAttr( /// This is used to create an ONNXConstant of rank 0, e.g. tensor. DenseElementsAttr createScalarDenseAttr( PatternRewriter &rewriter, Attribute attr) { - if (attr.dyn_cast()) { + if (mlir::dyn_cast(attr)) { Type elementType = rewriter.getF32Type(); SmallVector wrapper; - wrapper.emplace_back(attr.cast().getValueAsDouble()); + wrapper.emplace_back(mlir::cast(attr).getValueAsDouble()); return DenseElementsAttr::get( RankedTensorType::get({}, elementType), llvm::ArrayRef(wrapper)); } - if (attr.dyn_cast()) { + if (mlir::dyn_cast(attr)) { Type elementType = rewriter.getIntegerType(64); SmallVector wrapper; - wrapper.emplace_back(attr.cast().getSInt()); + wrapper.emplace_back(mlir::cast(attr).getSInt()); return DenseElementsAttr::get( RankedTensorType::get({}, elementType), llvm::ArrayRef(wrapper)); } @@ -133,7 +134,7 @@ Value createSequenceConstructOp( Value reverseAllElements( PatternRewriter &rewriter, Location loc, Value input, int64_t dimension) { onnx_mlir::MultiDialectBuilder create(rewriter, loc); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); ArrayRef inputShape = inputType.getShape(); SmallVector sLens; assert((dimension == 0 or dimension == 1) && @@ -154,7 +155,7 @@ Value reverseAllElements( for (int i = 0; i < inputShape[batchAxis]; ++i) sLens.emplace_back(inputShape[timeAxis]); Value sLensVal = create.onnx.constantInt64(sLens); - Type resultType = input.getType().cast(); + Type resultType = mlir::cast(input.getType()); Value result = create.onnx.reverseSequence( resultType, input, sLensVal, batchAxis, timeAxis); return result; @@ -179,7 +180,7 @@ Value reverseAllElements( Value reverseWeightTensor( PatternRewriter &rewriter, Location loc, Value input) { onnx_mlir::MultiDialectBuilder create(rewriter, loc); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); assert(inputType.hasRank() && "Need rank to reverse weight tensor."); // 1. Transpose NxCxD0xD1xD2x... to D0xD1xD2x ... xNxC. @@ -193,7 +194,7 @@ Value reverseWeightTensor( ArrayRef perms(permsVal); Value transposedInput = create.onnx.transposeInt64(input, perms); // 2. Reverse the first and second spatial dimensions. - ShapedType tInputType = transposedInput.getType().cast(); + ShapedType tInputType = mlir::cast(transposedInput.getType()); for (int i = 0; i < spatialRank / 2; i += 2) { // TODO: Support dynamic dim in reverseAllElements(). assert((!tInputType.isDynamicDim(0) && !tInputType.isDynamicDim(1)) && @@ -213,7 +214,7 @@ Value reverseWeightTensor( } // 3. Reverse the rest of dimension if spatial rank is odd. if (spatialRank % 2 != 0) { - ShapedType tInType = transposedInput.getType().cast(); + ShapedType tInType = mlir::cast(transposedInput.getType()); ArrayRef tInShape = tInType.getShape(); Value reverse0; if (tInShape[1] == ShapedType::kDynamic) { @@ -345,7 +346,7 @@ bool shouldDecomposeConvTransposeOp(Value convTransposeResult) { ValueRange emitSplitAxisOutputLength1( PatternRewriter &rewriter, Location loc, Value input, int64_t axis) { onnx_mlir::MultiDialectBuilder create(rewriter, loc); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); ArrayRef inputShape = inputType.getShape(); // Create `split` to split each output in `axis` into length 1. @@ -385,7 +386,7 @@ Value insertPadAxis(PatternRewriter &rewriter, Location loc, Value input, ValueRange padInputs = splitResults.drop_back(); SmallVector padResults; for (Value v : padInputs) { - ArrayRef vShape = v.getType().cast().getShape(); + ArrayRef vShape = mlir::cast(v.getType()).getShape(); padResults.emplace_back( emitPadsAxisEnd(rewriter, loc, v, vShape, axis, padSize)); } @@ -431,7 +432,7 @@ Value insertAdditionalPadsConvTranspose(PatternRewriter &rewriter, Location loc, ONNXConvTransposeOpShapeHelper shapeHelper(op.getOperation(), {}); shapeHelper.computeShapeAndAssertOnFailure(); SmallVector padSize; - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); int64_t spatialOffset = 2; int64_t spatialRank = inputType.getRank() - spatialOffset; DimsExpr outputDims = shapeHelper.getOutputDims(); @@ -447,7 +448,7 @@ Value insertAdditionalPadsConvTranspose(PatternRewriter &rewriter, Location loc, rewriter, loc, input, ArrayRef(inputShape), /*axis*/ 2, padSize[0]); for (int i = 1; i < spatialRank; ++i) { ArrayRef paddedInputShape = - paddedInput.getType().cast().getShape(); + mlir::cast(paddedInput.getType()).getShape(); paddedInput = emitPadsAxisEnd(rewriter, loc, paddedInput, paddedInputShape, /*axis*/ 2 + i, padSize[i]); } @@ -457,21 +458,21 @@ Value insertAdditionalPadsConvTranspose(PatternRewriter &rewriter, Location loc, Value normalizeConstantOp( PatternRewriter &rewriter, Value output, Attribute attr) { - ShapedType outputType = output.getType().cast(); + ShapedType outputType = mlir::cast(output.getType()); Type elementType = outputType.getElementType(); DenseElementsAttr denseAttr; - if (ArrayAttr arrayAttr = attr.dyn_cast()) { + if (ArrayAttr arrayAttr = mlir::dyn_cast(attr)) { int64_t dim = arrayAttr.size(); auto tensorType = RankedTensorType::get({dim}, elementType); denseAttr = DenseElementsAttr::get(tensorType, arrayAttr.getValue()); } else { auto tensorType = RankedTensorType::get({}, elementType); - if (FloatAttr floatAttr = attr.dyn_cast()) { + if (FloatAttr floatAttr = mlir::dyn_cast(attr)) { denseAttr = DenseElementsAttr::get(tensorType, {floatAttr.getValue()}); - } else if (IntegerAttr intAttr = attr.dyn_cast()) { + } else if (IntegerAttr intAttr = mlir::dyn_cast(attr)) { denseAttr = DenseElementsAttr::get(tensorType, intAttr.getSInt()); - } else if (StringAttr strAttr = attr.dyn_cast()) { + } else if (StringAttr strAttr = mlir::dyn_cast(attr)) { denseAttr = DenseElementsAttr::get(tensorType, {strAttr.getValue()}); } else { llvm_unreachable("unexpected Attribute"); @@ -491,7 +492,8 @@ namespace { RankedTensorType createResultType( Type outputType, int64_t axisValue, bool keepDims) { - RankedTensorType outputShapeType = outputType.dyn_cast(); + RankedTensorType outputShapeType = + mlir::dyn_cast(outputType); llvm::ArrayRef shapeVector = outputShapeType.getShape(); int64_t rank = outputShapeType.getRank(); if (axisValue < 0) @@ -732,7 +734,7 @@ struct CustomOpFuseMatMulPattern : public OpRewritePattern { // A must have rank 4 as perm has 4 indices. if (isTransA) { if (onnx_mlir::hasShapeAndRank(A)) { - rankA = A.getType().cast().getRank(); + rankA = mlir::cast(A.getType()).getRank(); } else { if (isa(A)) return false; @@ -749,7 +751,7 @@ struct CustomOpFuseMatMulPattern : public OpRewritePattern { rankA = -1; if (isTransB) { if (onnx_mlir::hasShapeAndRank(B)) { - rankB = B.getType().cast().getRank(); + rankB = mlir::cast(B.getType()).getRank(); } else { if (isa(B)) return false; @@ -790,7 +792,7 @@ struct InstanceNormIntoLayerNormPattern // Get info. Value scale = instanceNormOp.getScale(); Value bias = instanceNormOp.getB(); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); auto inputShape = inputType.getShape(); int64_t C = inputShape[1]; @@ -842,7 +844,7 @@ struct GroupNormIntoLayerNormPattern // Get info. Value scale = groupNormOp.getScale(); Value bias = groupNormOp.getBias(); - ShapedType inputType = input.getType().cast(); + ShapedType inputType = mlir::cast(input.getType()); Type elementType = inputType.getElementType(); auto inputShapeVal = inputType.getShape(); int64_t C = inputShapeVal[1]; @@ -938,7 +940,7 @@ class ReplaceCastLikeByCastPattern : public OpRewritePattern { IntegerAttr saturate = castLikeOp.getSaturateAttr(); // The output type will be the same as the target_type or the second input - Type outputType = target.getType().cast().getElementType(); + Type outputType = mlir::cast(target.getType()).getElementType(); // Replace Value res = onnx_mlir::OnnxBuilder(rewriter, loc) diff --git a/src/Dialect/ONNX/Transforms/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td index f51c296b6a..9f445990a4 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.td +++ b/src/Dialect/ONNX/Transforms/Decompose.td @@ -32,7 +32,7 @@ def KeepdimsTrue "/*isSigned=*/true), APInt(64, 1, /*isSigned=*/true))">; def KeepdimsIsTrue - : Constraint().getSInt() == 1">, + : Constraint($_self).getSInt() == 1">, "keepdims attribute is true">; def ONNXConstantOpNormalize: NativeCodeCall< @@ -42,10 +42,10 @@ def AttributeIsNull : Constraint, "Attribute is null">; def AttributeIsNotNull : Constraint, "Attribute is not null">; -def HasFloatType : Constraint()" +def HasFloatType : Constraint(($_self).getType())" ".getElementType().isF32())">>; -def IsNoneType : Constraint())">>; +def IsNoneType : Constraint(($_self).getType())">>; def GetNullAttr : NativeCodeCall<"Attribute()">; @@ -97,7 +97,7 @@ def createSequenceConstructOp : NativeCodeCall< def ONNXDataType : NativeCodeCall< "IntegerAttr::get($_builder.getIntegerType(64, /*isSigned=*/true), " - "::onnx_mlir::mlirTypeToOnnxType($0.getType().front().cast().getElementType()))">; + "::onnx_mlir::mlirTypeToOnnxType(mlir::cast($0.getType().front()).getElementType()))">; //===----------------------------------------------------------------------===// // ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X) @@ -226,7 +226,7 @@ def UpsamplePattern : Pat< (ONNXUpsampleOp $x, $scales, $mode), (ONNXResizeOp $x, (CreateNoneValue), $scales, (CreateNoneValue), (GetNullIntegerAttr), (GetNullArrayAttr), (GetNullStringAttr), - (GetNullFloatAttr), (GetNullIntegerAttr), (GetNullFloatAttr), (GetNullStringAttr), + (GetNullFloatAttr), (GetNullIntegerAttr), (GetNullFloatAttr), (GetNullStringAttr), $mode, (GetNullFloatAttr)) >; @@ -235,7 +235,7 @@ def UpsampleV7Pattern : Pat< (ONNXResizeOp $x, (CreateNoneValue), (ONNXConstantOpFromDenseAttr(createDenseArrayAttr $scales)), (CreateNoneValue), (GetNullIntegerAttr), (GetNullArrayAttr), (GetNullStringAttr), - (GetNullFloatAttr), (GetNullIntegerAttr), (GetNullFloatAttr), (GetNullStringAttr), + (GetNullFloatAttr), (GetNullIntegerAttr), (GetNullFloatAttr), (GetNullStringAttr), $mode, (GetNullFloatAttr)) >; @@ -400,7 +400,7 @@ def ScatterPattern : Pat< >; //===----------------------------------------------------------------------===// -// ONNXReduceL1V13Op %X -> ONNXReduceL1Op %X = +// ONNXReduceL1V13Op %X -> ONNXReduceL1Op %X = //===----------------------------------------------------------------------===// def ReduceL1V13OpPattern1 : Pat<(ONNXReduceL1V13Op $oprd, $axes, $keepdims), @@ -416,7 +416,7 @@ def ReduceL1V13OpPattern2 [], [], (addBenefit 0)>; //===----------------------------------------------------------------------===// -// ONNXReduceL2V13Op %X -> ONNXReduceL2Op %X = +// ONNXReduceL2V13Op %X -> ONNXReduceL2Op %X = //===----------------------------------------------------------------------===// def ReduceL2V13OpPattern1 : Pat<(ONNXReduceL2V13Op $oprd, $axes, $keepdims), @@ -432,7 +432,7 @@ def ReduceL2V13OpPattern2 [], [], (addBenefit 0)>; //===----------------------------------------------------------------------===// -// ONNXReduceLogSumV13Op %X -> ONNXReduceLogSumOp %X = +// ONNXReduceLogSumV13Op %X -> ONNXReduceLogSumOp %X = //===----------------------------------------------------------------------===// def ReduceLogSumV13OpPattern1 : Pat<(ONNXReduceLogSumV13Op $oprd, $axes, $keepdims), @@ -448,7 +448,7 @@ def ReduceLogSumV13OpPattern2 [], [], (addBenefit 0)>; //===----------------------------------------------------------------------===// -// ONNXReduceLogSumExpV13Op %X -> ONNXReduceLogSumExpOp %X = +// ONNXReduceLogSumExpV13Op %X -> ONNXReduceLogSumExpOp %X = //===----------------------------------------------------------------------===// def ReduceLogSumExpV13OpPattern1 : Pat<(ONNXReduceLogSumExpV13Op $oprd, $axes, $keepdims), @@ -464,7 +464,7 @@ def ReduceLogSumExpV13OpPattern2 [], [], (addBenefit 0)>; //===----------------------------------------------------------------------===// -// ONNXReduceSumSquareV13Op %X -> ONNXReduceSumSquareOp %X = +// ONNXReduceSumSquareV13Op %X -> ONNXReduceSumSquareOp %X = //===----------------------------------------------------------------------===// def ReduceSumSquareV13OpPattern1 : Pat<(ONNXReduceSumSquareV13Op $oprd, $axes, $keepdims), @@ -565,7 +565,7 @@ def ConvTransposeOpPattern2: Pattern< def ConstantOfShapePattern: Pat< (ONNXConstantOfShapeOp:$res $shape, $value), - (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 $value)), + (ONNXExpandOp (ONNXConstantOpFromDenseAttr (ReshapeElementsAttrToRank0 $value)), $shape) >; diff --git a/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp b/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp index 4d86634c49..0c8864c556 100644 --- a/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp +++ b/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp @@ -143,7 +143,7 @@ class Decomposer { const einsum::Signature &signature, ValueRange values) : builder(builder), loc(loc), create(builder, loc) { assert(values.size() >= 1 && "Einsum must have >= 1 inputs"); - elementType = values[0].getType().cast().getElementType(); + elementType = mlir::cast(values[0].getType()).getElementType(); result = signature.output; assert(values.size() == signature.inputs.size() && "Einsum signature inputs (from equation) must match actual inputs"); @@ -532,9 +532,9 @@ class Decomposer { // currently limited to the types supported by ReduceSum and MatMul (which // we decompose to in most cases) which exclude integers with width < 32 bool isDecomposableElementType(Type elementType) { - if (elementType.isa()) + if (mlir::isa(elementType)) return true; - if (IntegerType intType = elementType.dyn_cast()) + if (IntegerType intType = mlir::dyn_cast(elementType)) return intType.getWidth() >= 32; return false; } @@ -551,7 +551,8 @@ LogicalResult DecomposeEinsumPattern::matchAndRewrite( Location loc = einsumOp.getLoc(); ValueRange inputs = einsumOp.getInputs(); - Type elementType = inputs[0].getType().cast().getElementType(); + Type elementType = + mlir::cast(inputs[0].getType()).getElementType(); if (!isDecomposableElementType(elementType)) return rewriter.notifyMatchFailure( loc, "unsupported element type prevents Einsum decomposition"); @@ -589,7 +590,8 @@ LogicalResult DecomposeEinsumPattern::matchAndRewrite( bool DecomposeEinsumPattern::isDecomposable(mlir::ONNXEinsumOp einsumOp) { // TODO: deduplicate repeated logic from matchAndRewrite() ValueRange inputs = einsumOp.getInputs(); - Type elementType = inputs[0].getType().cast().getElementType(); + Type elementType = + mlir::cast(inputs[0].getType()).getElementType(); return isDecomposableElementType(elementType) && llvm::all_of(inputs.getTypes(), hasStaticShape) && hasStaticShape(einsumOp.getOutput().getType()); diff --git a/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp b/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp index 7191871f9f..b5e6826dd9 100644 --- a/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp +++ b/src/Dialect/ONNX/Transforms/ONNXPreKrnlVerifyPass.cpp @@ -62,13 +62,13 @@ class ONNXPreKrnlVerifyPass : public mlir::PassWrapper()) { - auto seqTy = ty.cast(); - if (!seqTy.getElementType().isa()) { + if (mlir::isa(ty)) { + auto seqTy = mlir::cast(ty); + if (!mlir::isa(seqTy.getElementType())) { op.emitError("SeqType with unranked Sequence Element"); return failure(); } - } else if (!ty.isa() && !ty.isa()) { + } else if (!mlir::isa(ty) && !mlir::isa(ty)) { op.emitError("not ranked"); return failure(); } else if (ONNXGatherNDOp gatherNDOp = diff --git a/src/Dialect/ONNX/Transforms/Recompose.cpp b/src/Dialect/ONNX/Transforms/Recompose.cpp index e95bdb3a3d..9d59537a75 100644 --- a/src/Dialect/ONNX/Transforms/Recompose.cpp +++ b/src/Dialect/ONNX/Transforms/Recompose.cpp @@ -237,7 +237,7 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { // Check axes. if (!hasShapeAndRank(dd)) return reportFailure("RMS need rank and shape for input dd"); - int64_t ddRank = dd.getType().cast().getRank(); + int64_t ddRank = mlir::cast(dd.getType()).getRank(); int64_t varAxis; if (!suitableAxis(vReduceOp, ddRank, varAxis)) return reportFailure("RMS unsuitable var reduce axes"); @@ -278,7 +278,7 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { if (hasFullPattern) { if (!hasShapeAndRank(x1)) return reportFailure("LN need rank and shape for input x"); - int64_t x1Rank = x1.getType().cast().getRank(); + int64_t x1Rank = mlir::cast(x1.getType()).getRank(); int64_t meanAxis; if (!suitableAxis(mReduceOp, x1Rank, meanAxis)) hasFullPattern = reportFailure("LN unsuitable mean reduce axes"); diff --git a/src/Dialect/ONNX/Transforms/ShapeInference.cpp b/src/Dialect/ONNX/Transforms/ShapeInference.cpp index 86b109907b..ff29a8734b 100644 --- a/src/Dialect/ONNX/Transforms/ShapeInference.cpp +++ b/src/Dialect/ONNX/Transforms/ShapeInference.cpp @@ -27,7 +27,7 @@ bool hasDynamicOrUnknownShape(Type type) { if (auto tensorType = dyn_cast(type)) return !tensorType.hasStaticShape(); - if (type.isa()) + if (mlir::isa(type)) return false; if (auto seqType = dyn_cast(type)) diff --git a/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp index 0b0d236d4a..efcf1e0cca 100644 --- a/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp +++ b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp @@ -80,7 +80,7 @@ void getDimsInt64(Value val, SmallVectorImpl &result) { getDims(val, dims); for (Value v : dims) { if (auto constOp = dyn_cast(v.getDefiningOp())) { - auto valueAttr = constOp.getValueAttr().cast(); + auto valueAttr = mlir::cast(constOp.getValueAttr()); int64_t dim = valueAttr.getSplatValue(); result.emplace_back(dim); } else { @@ -96,7 +96,7 @@ Value emitConcatOpForDims(MultiDialectBuilder create, if (rank == 1) { // Input is tensor<1xf32>, squeeze it if the output type is scalar i.e. // tensor - if (auto tensorType = outputType.dyn_cast()) { + if (auto tensorType = mlir::dyn_cast(outputType)) { if (tensorType.getRank() == 0) { Value zero = create.onnx.constantInt64({0}); return create.onnx.squeeze(outputType, inputs[0], zero); diff --git a/src/IR/AttrBase.td b/src/IR/AttrBase.td index 1bdbf68b9b..b8f8e1ad5e 100644 --- a/src/IR/AttrBase.td +++ b/src/IR/AttrBase.td @@ -15,15 +15,15 @@ include "mlir/IR/AttrTypeBase.td" -// Attribute for tensor layouts. +// Attribute for tensor layouts. class BaseLayoutAttr traits = []> : AttrDef { let mnemonic = "layout"; } def LayoutAttr : Attr< - CPred<"$_self.isa<::mlir::Attribute>() ">, - //"&& $_self.cast().getMnemonic().equal(\"layout\")">, + CPred<"mlir::isa<::mlir::Attribute>($_self) ">, + //"&& mlir::cast($_self).getMnemonic().equal(\"layout\")">, "layout attribute" >; diff --git a/src/Support/KrnlSupport.cpp b/src/Support/KrnlSupport.cpp index 251d6b3163..5c2f7d1821 100644 --- a/src/Support/KrnlSupport.cpp +++ b/src/Support/KrnlSupport.cpp @@ -123,12 +123,13 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); - } else if (elementType.isa()) { - auto stringType = elementType.cast(); + } else if (mlir::isa(elementType)) { + auto stringType = mlir::cast(elementType); sizeInBits = stringType.getElementSize(); } else { - assert(elementType.isa() && "elementType is not a VectorType"); - auto vectorType = elementType.cast(); + assert(mlir::isa(elementType) && + "elementType is not a VectorType"); + auto vectorType = mlir::cast(elementType); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } @@ -137,7 +138,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { /// Get the size of a static MemRef in bytes. int64_t getMemRefSizeInBytes(Value value) { - MemRefType memRefType = value.getType().dyn_cast(); + MemRefType memRefType = mlir::dyn_cast(value.getType()); auto memRefShape = memRefType.getShape(); int64_t size = 1; for (unsigned int i = 0; i < memRefShape.size(); i++) @@ -150,9 +151,9 @@ int64_t getMemRefSizeInBytes(Value value) { /// If all the dimensions are static, emit a constant. /// Otherwise, emit runtime computations. Value getDynamicMemRefSize(PatternRewriter &rewriter, Location loc, Value val) { - assert( - val.getType().isa() && "Value type should be a MemRefType"); - MemRefType memRefType = val.getType().cast(); + assert(mlir::isa(val.getType()) && + "Value type should be a MemRefType"); + MemRefType memRefType = mlir::cast(val.getType()); auto shape = memRefType.getShape(); // Accumulate static dimensions first. int64_t staticSizeInBytes = 1; @@ -185,9 +186,9 @@ Value getDynamicMemRefSize(PatternRewriter &rewriter, Location loc, Value val) { /// Otherwise, emit runtime computations. Value getDynamicMemRefSizeInBytes( PatternRewriter &rewriter, Location loc, Value val) { - assert( - val.getType().isa() && "Value type should be a MemRefType"); - MemRefType memRefType = val.getType().cast(); + assert(mlir::isa(val.getType()) && + "Value type should be a MemRefType"); + MemRefType memRefType = mlir::cast(val.getType()); auto shape = memRefType.getShape(); // Accumulate static dimensions first. int64_t staticSizeInBytes = getMemRefEltSizeInBytes(memRefType); @@ -255,7 +256,7 @@ Value getDynamicMemRefSizeInBytes(MemRefType type, Location loc, /// int64_t getAllocArgIndex(memref::AllocOp allocOp, int64_t index) { auto memRefShape = - allocOp.getResult().getType().dyn_cast().getShape(); + mlir::dyn_cast(allocOp.getResult().getType()).getShape(); auto rank = memRefShape.size(); int dynDimIdx = 0; diff --git a/src/Support/TypeUtilities.cpp b/src/Support/TypeUtilities.cpp index 1405991527..7d1a05cbb5 100644 --- a/src/Support/TypeUtilities.cpp +++ b/src/Support/TypeUtilities.cpp @@ -27,26 +27,26 @@ Type getElementType(Type ty) { return getElementTypeOrSelf(ty); } /// Check if a type is ShapedType and has rank. bool isRankedShapedType(Type ty) { - return (ty.isa() && ty.cast().hasRank()); + return (mlir::isa(ty) && mlir::cast(ty).hasRank()); } /// Check if a type has static shape. bool hasStaticShape(mlir::Type ty) { if (!isRankedShapedType(ty)) return false; - return ty.cast().hasStaticShape(); + return mlir::cast(ty).hasStaticShape(); } /// Get shape. ArrayRef getShape(Type ty) { assert(isRankedShapedType(ty) && "Type must be ranked"); - return ty.cast().getShape(); + return mlir::cast(ty).getShape(); } /// Get rank. int64_t getRank(Type ty) { assert(isRankedShapedType(ty) && "Type must be ranked"); - return ty.cast().getRank(); + return mlir::cast(ty).getRank(); } /// Get the number of elements. @@ -63,7 +63,7 @@ int64_t getEltSizeInBytes(Type ty) { if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { - auto vectorType = elementType.cast(); + auto vectorType = mlir::cast(elementType); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } diff --git a/test/modellib/ConvModel.cpp b/test/modellib/ConvModel.cpp index efa4416b61..7841e81153 100644 --- a/test/modellib/ConvModel.cpp +++ b/test/modellib/ConvModel.cpp @@ -105,7 +105,8 @@ bool Conv2DLibBuilder::build() { if (failed(res)) return false; - auto outputShape = convOp.getResult().getType().cast().getShape(); + auto outputShape = + mlir::cast(convOp.getResult().getType()).getShape(); modelNOut = outputShape[0]; modelCOut = outputShape[1]; modelHOut = outputShape[2]; @@ -123,7 +124,7 @@ bool Conv2DLibBuilder::build() { bool Conv2DLibBuilder::prepareInputs(float dataRangeLB, float dataRangeUB) { constexpr int num = 2; - OMTensor* list[num]; + OMTensor *list[num]; list[0] = omTensorCreateWithRandomData( {N, CIn, H, W}, dataRangeLB, dataRangeUB); list[1] = omTensorCreateWithRandomData( diff --git a/test/unit/DisposableElementsAttr/TestDisposableElementsAttr.cpp b/test/unit/DisposableElementsAttr/TestDisposableElementsAttr.cpp index 4f6a6ce242..5b21bef540 100644 --- a/test/unit/DisposableElementsAttr/TestDisposableElementsAttr.cpp +++ b/test/unit/DisposableElementsAttr/TestDisposableElementsAttr.cpp @@ -113,9 +113,9 @@ class Test { cpptype one(1); Attribute a = elmsBuilder.toDisposableElementsAttr( DenseElementsAttr::get(type, one)); - ElementsAttr e = a.cast(); + ElementsAttr e = mlir::cast(a); assert(e.isSplat()); - DisposableElementsAttr i = e.cast(); + DisposableElementsAttr i = mlir::cast(e); assert(i.isSplat()); assert(eq(i.getSplatValue(), one)); diff --git a/test/unit/Einsum/TestONNXEinsumOp.cpp b/test/unit/Einsum/TestONNXEinsumOp.cpp index 8c50522724..ef0c616a95 100644 --- a/test/unit/Einsum/TestONNXEinsumOp.cpp +++ b/test/unit/Einsum/TestONNXEinsumOp.cpp @@ -64,9 +64,9 @@ class Test { Type I32; Attribute zero(Type t) { - if (t.isa()) + if (mlir::isa(t)) return FloatAttr::get(t, 0); - assert(t.isa() && "must be IntegerType if not FloatType"); + assert(mlir::isa(t) && "must be IntegerType if not FloatType"); return IntegerAttr::get(t, 0); } @@ -231,7 +231,7 @@ class Test { bool inferenceSuccess = succeeded(op.inferShapes(nullptr)); if (expectSuccess && inferenceSuccess) { auto outputShape = - op.getResult().getType().cast().getShape(); + mlir::cast(op.getResult().getType()).getShape(); if (expectedOutputShape != outputShape) { std::cerr << "inferred output shape " << outputShape << " != expected " << expectedOutputShape << "\n"; diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index a77302388c..78b52deb6f 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -545,11 +545,11 @@ """ let builders = [ OpBuilder<(ins "Attribute":$sparse_value, "Attribute":$value), [{ if (value) { - auto tensorType = value.cast().getType(); + auto tensorType = mlir::cast(value).getType(); build($_builder, $_state, tensorType, sparse_value, value, FloatAttr(), ArrayAttr(), IntegerAttr(), ArrayAttr(), StringAttr(), ArrayAttr()); } else { - auto tensorType = sparse_value.cast().getType(); + auto tensorType = mlir::cast(sparse_value).getType(); build($_builder, $_state, tensorType, sparse_value, value, FloatAttr(), ArrayAttr(), IntegerAttr(), ArrayAttr(), StringAttr(), ArrayAttr()); } @@ -1269,19 +1269,19 @@ def gen_op_def(schema, with_version=False): + ");\n" ) r += ( - "{indent}auto shapedType = resultType.dyn_cast_or_null();\n" + "{indent}auto shapedType = mlir::dyn_cast_or_null(resultType);\n" ) r += "{indent}if (!shapedType || !shapedType.hasStaticShape())\n" r += ( "{indent} resultType = UnrankedTensorType::get(" - + (elTy if elTy else "lhsTy.cast().getElementType()") + + (elTy if elTy else "mlir::cast(lhsTy).getElementType()") + ");\n" ) else: numOperands = 1 r += ( "{indent}auto resultType = UnrankedTensorType::get(" - + "{0}.getType().cast().getElementType());\n" + + "mlir::cast({0}.getType()).getElementType());\n" ) resultType = r