From 98c6971a017460eb9daf1df39d724a7f728f2d13 Mon Sep 17 00:00:00 2001 From: Branko Trifkovic <88882867+BaneTrifa@users.noreply.github.com> Date: Sat, 22 Jun 2024 01:16:38 +0200 Subject: [PATCH 01/30] Implement lowering of torch.aten.triu_indices (#3451) Closes [nod-ai/SHARK-Turbine/issues/709](https://github.com/nod-ai/SHARK-Turbine/issues/709) --------- Co-authored-by: Branko Trifkovic --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 30 ++ lib/Dialect/Torch/IR/TorchOps.cpp | 36 +++ .../Transforms/AbstractInterpLibrary.cpp | 74 +++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 301 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 35 ++ .../build_tools/torch_ods_gen.py | 5 + .../test_suite/elementwise.py | 60 ++++ 9 files changed, 545 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index dce6018e1a7e3..b836b6bab5b68 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -15517,6 +15517,36 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [ let hasCanonicalizer = 1; } +def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$row, + Torch_IntType:$col, + Torch_IntType:$offset, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTriuIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenTriuIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index b0bb555116f7f..c37b96c60f664 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5212,3 +5212,39 @@ LogicalResult BindSymbolicShapeOp::verify() { return success(); } +// AtenTriuIndicesOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenTriuIndicesOp::verify() { + + // Check if row, col and offset are constant ints + int64_t row; + if (!matchPattern(getRow(), m_TorchConstantInt(&row))) + return success(); + + int64_t col; + if (!matchPattern(getCol(), m_TorchConstantInt(&col))) + return success(); + + int64_t offset; + if (!matchPattern(getOffset(), m_TorchConstantInt(&offset))) + return success(); + + // Check if values of row, and col are valid + if (row < 0) + return emitOpError("row must be non-negative, got ") << row; + + if (col < 0) + return emitOpError("col must be non-negative, got ") << col; + + // Check if dtype is valid + int64_t dtype; + if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) + return success(); + if (dtype != (int)torch_upstream::ScalarType::Int && + dtype != (int)torch_upstream::ScalarType::Long) + return emitOpError( + "'triu_indices' implemented only for torch.int32 and torch.int64"); + + return success(); +} diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 408709816cb8b..e9147d5853eca 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9729,6 +9729,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int, !torch.optional>, !torch.optional) -> !torch.tuple, list, list, list>\n" " return %1 : !torch.tuple, list, list, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = torch.aten.sub.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %6:2 = torch.prim.If %5 -> (!torch.int, !torch.int) {\n" +" torch.prim.If.yield %int0, %int0 : !torch.int, !torch.int\n" +" } else {\n" +" %11 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %27 = torch.aten.add.int %int1, %3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.prim.min.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %28 : !torch.int\n" +" } else {\n" +" %27 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %29 = torch.aten.Int.bool %28 : !torch.bool -> !torch.int\n" +" torch.prim.If.yield %29 : !torch.int\n" +" }\n" +" %13 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.prim.min.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.prim.max.int %int0, %14 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.prim.min.int %arg0, %16 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.sub.int %15, %12 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.mul.int %21, %20 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.floordiv.int %22, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.sub.int %18, %20 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.mul.int %24, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.prim.max.int %int0, %25 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %23, %26 : !torch.int, !torch.int\n" +" }\n" +" %7 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.add.int %6#0, %6#1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %7, %8 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.prim.ListConstruct %int2, %9 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %10 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.optional>, !torch.int) -> !torch.tuple, list>\n" " return %0 : !torch.tuple, list>\n" @@ -14023,6 +14085,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int6 = torch.constant.int 6\n" " return %int6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int3 = torch.constant.int 3\n" " %int1 = torch.constant.int 1\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a72e583fa9fa6..04f505bea6793 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -732,6 +732,306 @@ class DecomposeAtenTriuOp : public OpRewritePattern { }; } // namespace +/* + This function calculates the number of elements in the lower triangle (below + the main diagonal) of a tensor with dimensions [row, col]. The main diagonal + can be shifted using the 'offset' parameter. The lower triangle is divided into + two parts: a trapezoid and a rectangle. The return tuple includes the number of + elements in the trapezoid, the number of elements in the rectangle, and the + index of the first row such that the element [mFirstRow, 0] is below the main + diagonal. + */ +static std::tuple +getTrilSizes(int64_t row, int64_t col, int64_t offset) { + + // Base case + if (row == 0 || col == 0) { + return std::make_tuple(0, 0, 0); + } + + // Calculate mFirstRow size + int64_t mFirstRow; + if (offset > 0) + mFirstRow = (col < offset + 1) ? col : offset + 1; + else + mFirstRow = (row + offset > 0) ? 1 : 0; + + // Calculate mLastRow size + int64_t minimum = (col < row + offset) ? col : row + offset; + int64_t mLastRow = (minimum > 0) ? minimum : 0; + + // Calculate nRowAll + minimum = (row < row + offset) ? row : row + offset; + int64_t nRowAll = (minimum > 0) ? minimum : 0; + + // Calucltae nRowTrapezoid + int64_t nRowTrapezoid = mLastRow - mFirstRow + 1; + + // Number of elements in top trapezoid - trapezoidSize + int64_t trapezoidSize = (mFirstRow + mLastRow) * nRowTrapezoid / 2; + + // Number of elements in bottom rectangle - rectangleSize + int64_t diffRow = nRowAll - nRowTrapezoid; + int64_t rectangleSize = (diffRow * col > 0) ? diffRow * col : 0; + + // Create return value + return std::make_tuple(trapezoidSize, rectangleSize, mFirstRow); +} + +/* + This function calculates the number of elements in the upper triangle (above + the main diagonal) of a tensor with dimensions [row, col]. The main diagonal + can be shifted using the 'offset' parameter. The upper triangle is divided into + two parts: a trapezoid and a rectangle. The return tuple includes the number of + elements in the trapezoid, the number of elements in the rectangle, and the + index of the first row such that the element [mFirstRow, 0] is above the main + diagonal. + */ +static std::tuple +getTriuSizes(int64_t row, int64_t col, int64_t offset) { + + // Base case + if (row == 0 || col == 0) + return std::make_tuple(0, 0, 0); + + // Calculate mFirstRow size + int64_t maximum = (col - offset > 0) ? col - offset : 0; + int64_t mFirstRow = (offset > 0) ? maximum : col; + + // Number of elements in top rectangle - calculate rectangle size + int64_t minimum = (row < -offset) ? row : -offset; + int64_t rectangleSize = (minimum * col > 0) ? minimum * col : 0; + + // Number of elements in bottom trapezoid - calculte trapezoid size + std::tuple trilSizes = + getTrilSizes(row, col, offset - 1); + int64_t trapezoidSizeTril = std::get<0>(trilSizes); + int64_t rectangleSizeTril = std::get<1>(trilSizes); + + int64_t triuSize = row * col - (trapezoidSizeTril + rectangleSizeTril); + int64_t trapezoidSize = triuSize - rectangleSize; + + // Create return value + return std::make_tuple(trapezoidSize, rectangleSize, mFirstRow); +} + +// decomposition of torch.triu_indices +// https://github.com/pytorch/pytorch/blob/67ef2683d970fc541b6d266d4b3f8ba9d13844ca/torch/_refs/__init__.py#L5829 +namespace { +class DecomposeAtenTriuIndicesOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTriuIndicesOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + // Required parameters + Value row = op.getRow(); + Value col = op.getCol(); + Value offset = op.getOffset(); + + // Check if row, col and offset are constant ints + int64_t rowInt; + if (!matchPattern(row, m_TorchConstantInt(&rowInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: row not constant int"); + + int64_t colInt; + if (!matchPattern(col, m_TorchConstantInt(&colInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: col not constant int"); + + int64_t offsetInt; + if (!matchPattern(offset, m_TorchConstantInt(&offsetInt))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: offset not constant int"); + + // Optional parameters + Value dtype = op.getDtype(); + Value layout = op.getLayout(); + Value device = op.getDevice(); + Value pinMemory = op.getPinMemory(); + + // Get int value for dtype + int64_t dtypeInt; + if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: dtype not constant int"); + + FailureOr dtypeType = + getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + if (failed(dtypeType)) + return rewriter.notifyMatchFailure(op, "dtype is undefined"); + + // Constants + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstTwo = rewriter.create(loc, 2); + Value cstFalse = rewriter.create(loc, false); + Value cstMinusZeroPointFive = rewriter.create( + loc, rewriter.getF64FloatAttr(-0.5)); + Value cstMinusTwoFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(-2.0)); + + // Calculte trapezoidSize, rectangleSize and mFirstRow + std::tuple triuSizes = + getTriuSizes(rowInt, colInt, offsetInt); + + int64_t trapezoidSizeInt = std::get<0>(triuSizes); + int64_t rectangleSizeInt = std::get<1>(triuSizes); + int64_t mFirstRowInt = std::get<2>(triuSizes); + + // Create const int Values from ints + Value trapezoidSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); + Value rectangleSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); + Value mFirstRow = rewriter.create( + loc, rewriter.getI64IntegerAttr(mFirstRowInt)); + + // Calculte column offset + Value colOffset = (offsetInt > 0) ? offset : cstZero; + + // Calculate indices for top rectangle + auto arrangeType = + getTensorTypeFromShapeValues({rectangleSize}, *dtypeType); + Value xs2 = + rewriter.create(loc, arrangeType, rectangleSize, + /*dtype=*/dtype, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); + + // Calculate row_indices2 and column_idices 2 + Value rowInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + Value colInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + + // Bottom trapezoid + auto f64DtypeInt = + getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type()); + arrangeType = + getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type()); + Value xs1 = + rewriter.create(loc, arrangeType, trapezoidSize, + /*dtype=*/f64DtypeInt, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); + + // b = -0.5 - m_first_row + Value mFirstRowFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(mFirstRowInt)); + Value b = rewriter.create(loc, cstMinusZeroPointFive, + mFirstRowFloat); + + // Implements this piece of code: row_inds1 = torch.floor(-b - torch.sqrt(b + // * b - 2 * xs1)) + Value bSquare = rewriter.create(loc, b, b); + + Value twoTimesXs1 = rewriter.create(loc, xs1.getType(), + xs1, cstMinusTwoFloat); + Value sqrtInput = rewriter.create( + loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); + + Value sqrt = + rewriter.create(loc, sqrtInput.getType(), sqrtInput); + Value negativeSqrt = rewriter.create(loc, sqrt.getType(), sqrt); + + Value rowInds1 = rewriter.create( + loc, negativeSqrt.getType(), negativeSqrt, b, cstOne); + rowInds1 = rewriter.create(loc, rowInds1.getType(), rowInds1); + + // Implements this piece of code: col_inds1 = torch.floor(xs1 - ((2 * + // m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + Value twoTimesMFirstRow = + rewriter.create(loc, cstTwo, mFirstRow); + twoTimesMFirstRow = + rewriter.create(loc, twoTimesMFirstRow, cstOne); + Value negativeRowInds1 = + rewriter.create(loc, rowInds1.getType(), rowInds1); + + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, twoTimesMFirstRow, + cstOne); + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, rowInds1); + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, + cstMinusZeroPointFive); + + Value colInds1 = rewriter.create(loc, xs1.getType(), xs1, + negativeRowInds1, cstOne); + colInds1 = rewriter.create(loc, colInds1.getType(), colInds1); + + // Convert to dtype + Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true); + + auto rowInds1Type = cast(rowInds1.getType()); + ArrayRef sizes = rowInds1Type.getSizes(); + Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type); + rowInds1 = rewriter.create( + loc, finalRowType, rowInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); + + auto colInds1Type = cast(colInds1.getType()); + sizes = colInds1Type.getSizes(); + Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type); + colInds1 = rewriter.create( + loc, finalColType, colInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); + + // Final calculation for row and col indices + if (colInt) { + + Value rectangleSizeDivCol = + rewriter.create(loc, rectangleSizeInt / colInt); + + rowInds1 = rewriter.create( + loc, rowInds1.getType(), rowInds1, rectangleSizeDivCol, cstOne); + } + + colInds1 = rewriter.create(loc, colInds1.getType(), + colInds1, colOffset, cstOne); + + Type listElemType = + cast(rowInds1.getType()) + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + + Value sequenceRow = rewriter.create( + loc, listType, SmallVector{rowInds2, rowInds1}); + Value sequenceCol = rewriter.create( + loc, listType, SmallVector{colInds2, colInds1}); + + // Concatenate row and col indices + Type finalCatType = colInds1Type.getWithSizesAndDtype( + {rectangleSizeInt + trapezoidSizeInt}, int64Type); + + Value catRow = rewriter.create(loc, finalCatType, sequenceRow, + /*dim=*/cstZero); + Value catCol = rewriter.create(loc, finalCatType, sequenceCol, + /*dim=*/cstZero); + + // Make return value + Value sequence = rewriter.create( + loc, Torch::ListType::get(context, rowInds1.getType()), + ValueRange{catRow, catCol}); + Type finalStackType = colInds1Type.getWithSizesAndDtype( + ArrayRef{2, rectangleSizeInt + trapezoidSizeInt}, int64Type); + + rewriter.replaceOpWithNewOp(op, finalStackType, sequence, + cstZero); + + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: @@ -8399,6 +8699,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 301cb8e809d74..0006a97f44d2d 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -541,6 +541,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8617f1d79534b..fb997435faf71 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1321,6 +1321,9 @@ "TorchPrimLoopForLikeTensorArgModule_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", + "TriuIndicesModule_basic", + "TriuIndicesAllZerosModule_basic", + "TriuIndicesNegativeOffsetModule_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", "TypeAsSameModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8920de787d5e0..018377e45c169 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1760,6 +1760,38 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode, per_sample_weights, padding_idx) +@check_shape_function([ + Invocation(4, 3, 1), # Basic case. + Invocation(0, 0, 0), # All zeros case. + Invocation(7, 5, -2), # Negative offset case. + Invocation(35, 55, 16), # Largere values case. +]) +def aten〇triu_indices〡shape(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + if row == 0 or col == 0: + return [2, 0] + + # _get_tril_indices + offset_tril = offset - 1 + if row == 0 or col == 0: + trapezoid_size_tril = 0 + rectangle_size_tril = 0 + else: + m_first_row = min(col, 1 + offset_tril) if offset_tril > 0 else int(row + offset_tril > 0) + m_last_row = max(0, min(col, row + offset_tril)) + n_row_all = max(0, min(row, row + offset_tril)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size_tril = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size_tril = max(0, diff_row * col) + + # Number of elements in bottom trapezoid + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + + return [2, triu_size] + @check_shape_function([ Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case. Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim. @@ -4964,6 +4996,9 @@ def aten〇dequantize〇self〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> int: return torch.float32 +def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.int64 if dtype is None else dtype + def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype if (self_dtype == torch.quint8): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 1ad3b09ee7016..9cf8b26029648 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1065,6 +1065,11 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)") emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True) + emit( + "aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)", + has_verifier=True, + ) + # backprop ops emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index f3bcefc95330c..ce000264efec3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -6223,3 +6223,63 @@ def forward(self, x): ) def FakeQuantizePerTensorAffineRoundToEvenModule_basic(module, tu: TestUtils): module.forward(torch.FloatTensor([0.5, 1.5, -0.5, -1.5])) + + +# ============================================================================== + + +class TriuIndicesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(4, 3, 1) + + +@register_test_case(module_factory=lambda: TriuIndicesModule()) +def TriuIndicesModule_basic(module, tu: TestUtils): + module.forward() + + +class TriuIndicesAllZerosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(0, 0, 0) + + +@register_test_case(module_factory=lambda: TriuIndicesAllZerosModule()) +def TriuIndicesAllZerosModule_basic(module, tu: TestUtils): + module.forward() + + +class TriuIndicesNegativeOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(5, 16, -2) + + +@register_test_case(module_factory=lambda: TriuIndicesNegativeOffsetModule()) +def TriuIndicesNegativeOffsetModule_basic(module, tu: TestUtils): + module.forward() From fc19709daab6cd44a29d3b58a7a82ba267ad52b2 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Fri, 21 Jun 2024 17:24:57 -0700 Subject: [PATCH 02/30] [ONNX] Add averagepool dilations support (#3490) - To fix dilations issue: https://github.com/llvm/torch-mlir/issues/3428 - Test by: https://github.com/nod-ai/SHARK-TestSuite/pull/268 --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 53 ++++++++++++------- lib/Conversion/TorchToLinalg/Pooling.cpp | 10 ++++ 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index adde8ceaab402..6932908c05c6b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -379,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "AveragePool", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; - SmallVector dilation; + SmallVector dilations; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); if (autoPad != "NOTSET") { @@ -387,13 +387,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } - if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) { - return failure(); - } - if (dilation.size() > 0) { - return rewriter.notifyMatchFailure( - binder.op, "dilation is not supported by torch.aten.avgpool op"); - } Torch::ValueTensorType resultType; Value operand; @@ -436,7 +429,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "strides list size does not match the number of axes"); } - SmallVector cstKernel, cstPadding, cstStrides; + SmallVector cstKernel, cstPadding, cstStridesDilations; for (int64_t i : kernel) { cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); @@ -454,9 +447,24 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { - cstStrides.push_back(rewriter.create( + cstStridesDilations.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } + + // No dilations attribute in pytorch avgpool op, so use this trick to + // encode dilation into strides. Then in the following torchtolinalg + // lowering, decode strides into strides + dilation. + // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...] + if (binder.s64IntegerArrayAttr( + dilations, "dilations", + llvm::SmallVector(rank - 2, 1))) { + return failure(); + } + for (auto dilation : dilations) { + cstStridesDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dilation))); + } + Value kernelSizeList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), @@ -465,10 +473,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value stridesList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstStrides); + Value stridesDilationsList = + rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstStridesDilations); Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); Value cstCountIncludePad = rewriter.create( @@ -477,19 +487,22 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (rank == 3) { rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad); + binder.op, resultType, operand, kernelSizeList, + stridesDilationsList, paddingList, cstCeilMode, + cstCountIncludePad); return success(); } else if (rank == 4) { rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + binder.op, resultType, operand, kernelSizeList, + stridesDilationsList, paddingList, cstCeilMode, + cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } else if (rank == 5) { rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + binder.op, resultType, operand, kernelSizeList, + stridesDilationsList, paddingList, cstCeilMode, + cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index d80f3d4272e4a..1c3de11079f26 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -612,6 +612,16 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); + // Decode strideInts into strideInts and dilation + if (strideInts.size() == 2 * Dim) { + for (int i = 0; i < Dim; i++) { + dilationInts[i] = strideInts[Dim + i]; + } + for (int i = 0; i < Dim; i++) { + strideInts.pop_back(); + } + } + // TODO: Add support for count_include_pad equal to `False`. bool countIncludePad; if (!matchPattern(op.getCountIncludePad(), From 61f37ae8a39383952d187f0873d24b8f6ccb7bd6 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 24 Jun 2024 15:39:19 +0800 Subject: [PATCH 03/30] [fx importer] support fx importer with lower version torch (#3486) --- python/torch_mlir/extras/fx_importer.py | 42 ++++++++++++++++++------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 2a73325c7d76f..cb86406c55fd7 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -151,11 +151,17 @@ torch.complex32: "complex", torch.complex64: "complex", torch.complex128: "complex", - torch.float8_e5m2: "f8E5M2", - torch.float8_e4m3fn: "f8E4M3FN", - torch.float8_e5m2fnuz: "f8E5M2FNUZ", - torch.float8_e4m3fnuz: "f8E4M3FNUZ", } +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM = { + "float8_e5m2": "f8E5M2", + "float8_e4m3fn": "f8E4M3FN", + "float8_e5m2fnuz": "f8E5M2FNUZ", + "float8_e4m3fnuz": "f8E4M3FNUZ", +} +for dtype_str, dtype_asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, dtype_str)] = dtype_asm TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { torch.float16: lambda: F16Type.get(), @@ -173,11 +179,17 @@ torch.complex32: lambda: ComplexType.get(F16Type.get()), torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), - torch.float8_e5m2: lambda: Float8E5M2Type.get(), - torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(), - torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(), - torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(), } +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE = { + "float8_e5m2": lambda: Float8E5M2Type.get(), + "float8_e4m3fn": lambda: Float8E4M3FNType.get(), + "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(), + "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(), +} +for dtype_str, mlir_type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, dtype_str)] = mlir_type TORCH_DTYPE_TO_NPY_TYPE = { # torch.qint8: None, # no equivalent np datatype @@ -215,11 +227,17 @@ # torch.quint8: 13, # torch.qint32 14 torch.bfloat16: 15, - torch.float8_e5m2: 23, - torch.float8_e4m3fn: 24, - torch.float8_e5m2fnuz: 25, - torch.float8_e4m3fnuz: 26, } +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_INT = { + "float8_e5m2": 23, + "float8_e4m3fn": 24, + "float8_e5m2fnuz": 25, + "float8_e4m3fnuz": 26, +} +for dtype_str, dtype_int in OPTIONAL_TORCH_DTYPE_TO_INT.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_INT[getattr(torch, dtype_str)] = dtype_int TORCH_MEMORY_FORMAT_TO_INT = { torch.contiguous_format: 0, From 09f502667b400865843aea90f6f6b6c104969be4 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 24 Jun 2024 15:22:50 -0700 Subject: [PATCH 04/30] `AtenTensorOp::fold` should not fold when result type is not fully specified (#3494) In one of our downstreams, we encountered an internal assertion failure in an intermediate pass from `AtenTensorOp::fold` invocation: ``` external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed. ``` for this snippet in the IR: ``` %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>} ... %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor ``` Turns out this was [fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719) eventually (and we were on an old hash of torch-mlir). This PR submits just the lit test for test coverage on that specific change: ```c++ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { auto resultTy = dyn_cast(getType()); // lit test this if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; ... ``` --- test/Dialect/Torch/canonicalize.mlir | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 250f11cf67a1e..aa943a5a1e5a8 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1534,6 +1534,16 @@ func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) { return %67 : !torch.vtensor<[1],si64> } +// CHECK-LABEL: func.func @torch.aten.tensor$no_fold( +// CHECK: torch.aten.tensor %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor +func.func @torch.aten.tensor$no_fold(%arg0: !torch.tensor) -> (!torch.tensor) { + %none = torch.constant.none + %false = torch.constant.bool false + %1 = torch.aten.size %arg0 : !torch.tensor -> !torch.list + %2 = torch.aten.tensor %1, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor + return %2 : !torch.tensor +} + // CHECK-LABEL: func.func @torch.aten.tensor.float( // CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor) : !torch.vtensor<[],f32> func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> { From 3c3fbe4680cdd2725d4dacd59f3bb8a0064220d0 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 25 Jun 2024 12:58:31 +0530 Subject: [PATCH 05/30] [ONNX] Add OnnxToTorch lowering for Onnx.Upsample Op (#3371) Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 146 +++++++++++------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 42 +++++ 2 files changed, 136 insertions(+), 52 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6b003b1259c09..63eac34270db5 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -152,6 +152,55 @@ LogicalResult reducedSumImpl(OpBinder binder, } return success(); } + +Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter, + Value operand) { + SmallVector itemList; + auto sizes = dyn_cast(operand.getType()).getSizes(); + Torch::BaseTensorType operandType = + cast(operand.getType()); + + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = operandType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); + + auto extract = [&rewriter, &binder](Value x, Value v) { + auto xTy = cast(x.getType()); + Type extractTy = rewriter.getType(); + if (isa(xTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, v); + }; + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + MLIRContext *context = binder.op->getContext(); + for (int i = 2; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value ext = rewriter.create( + binder.getLoc(), selectResultType, operand, zero, selectIndex); + Value item = extract(operand, ext); + itemList.push_back(item); + } + auto xTy = cast(operand.getType()); + Value ValueList; + if (isa(xTy.getDtype())) { + ValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)), + itemList); + } else { + ValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(Torch::FloatType::get(context)), + itemList); + } + return ValueList; +} } // namespace void mlir::torch::onnx_c::populateDefaultDomainQtoZ( @@ -2830,62 +2879,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .getSizes() .size(); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - Value cstFalse = rewriter.create(binder.getLoc(), false); Value cstTrue = rewriter.create(binder.getLoc(), true); Value modeStrValue; - auto extract = [&rewriter, &binder](Value x, Value v) { - auto xTy = cast(x.getType()); - Type extractTy = rewriter.getType(); - if (isa(xTy.getDtype())) - extractTy = rewriter.getType(); - - return rewriter.create(binder.getLoc(), extractTy, - v); - }; - - auto getValueList = [&](Value operand) { - SmallVector itemList; - auto sizes = - dyn_cast(operand.getType()).getSizes(); - Torch::BaseTensorType operandType = - cast(operand.getType()); - - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = operandType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); - - MLIRContext *context = binder.op->getContext(); - for (int i = 2; i < sizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value ext = rewriter.create( - binder.getLoc(), selectResultType, operand, zero, selectIndex); - Value item = extract(operand, ext); - itemList.push_back(item); - } - auto xTy = cast(operand.getType()); - Value ValueList; - if (isa(xTy.getDtype())) { - ValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(context)), itemList); - } else { - ValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::FloatType::get(context)), itemList); - } - return ValueList; - }; - Value scalesValueList = noneVal; Value sizesValueList = noneVal; Value alignCorners = @@ -2934,12 +2933,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (operands.size() < 4) { Value scaleOperand = operands[2]; - scalesValueList = getValueList(scaleOperand); + scalesValueList = getValueList(binder, rewriter, scaleOperand); sizesValueList = noneVal; } else { Value sizeOperand = operands[3]; scalesValueList = noneVal; - sizesValueList = getValueList(sizeOperand); + sizesValueList = getValueList(binder, rewriter, sizeOperand); } if (isa(scalesValueList.getType()) && isa(sizesValueList.getType())) { @@ -3258,4 +3257,47 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, inputSequence); return success(); }); + patterns.onOp( + "Upsample", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + std::string mode; + Value input, scales; + if (binder.tensorOperands(input, scales) || + binder.customOpNameStringAttr(mode, "mode", "nearest") || + binder.tensorResultType(resultType)) { + return failure(); + } + + if (mode != "nearest" && mode != "linear") + return rewriter.notifyMatchFailure( + binder.op, "unsupported interpolation mode other than nearest, " + "linear"); + + int64_t resultRank = resultType.getSizes().size(); + if (resultRank > 5) + return rewriter.notifyMatchFailure( + binder.op, "supports upto 3d upsampling only"); + + Value scalesValueList = getValueList(binder, rewriter, scales); + if (mode == "linear") { + if (resultRank == 4) + mode = "bilinear"; + if (resultRank == 5) + mode = "trilinear"; + } + Value modeStrValue = + rewriter.create(binder.getLoc(), mode); + Value cstNone = rewriter.create(binder.getLoc()); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + + rewriter + .replaceOpWithNewOp( + binder.op, resultType, input, /*size=*/cstNone, scalesValueList, + modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ cstNone, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ cstNone, + /*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index ae47b49b06f33..8e37e1d832021 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2541,3 +2541,45 @@ func.func @test_sequence_empty() -> !torch.list> attributes {tor %0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> return %0 : !torch.list> } + +// ----- + +// CHECK-LABEL: func.func @test_upsample_nearest +func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "nearest" + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[UPSAMPLE:.*]] = torch.aten.__interpolate.size_list_scale_list %arg0, %[[NONE]], %[[SCALE_LIST:.*]], %[[MODE]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32> + // CHECK: return %[[UPSAMPLE]] : !torch.vtensor<[1,1,4,6],f32> + %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> + return %0 : !torch.vtensor<[1,1,4,6],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_upsample_bilinear +func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "bilinear" + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[UPSAMPLE:.*]] = torch.aten.__interpolate.size_list_scale_list %arg0, %[[NONE]], %[[SCALE_LIST:.*]], %[[MODE]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32> + // CHECK: return %[[UPSAMPLE]] : !torch.vtensor<[1,1,4,6],f32> + %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> + return %0 : !torch.vtensor<[1,1,4,6],f32> +} From 02340408b7bb909dce71269a031c699c4eb187f5 Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Tue, 25 Jun 2024 19:00:45 +0530 Subject: [PATCH 06/30] [torch] Add OnnxToTorch lowering for Onnx.STFT op (#3492) Adds OnnxToTorch lowering for `Onnx.STFT` op. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 30 ++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 166 ++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 256 ++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 60 ++++ .../build_tools/torch_ods_gen.py | 3 + .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 49 ++++ 6 files changed, 564 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b836b6bab5b68..c351d845c2f8d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12533,6 +12533,36 @@ def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [ let hasVerifier = 1; } +def Torch_AtenStftOp : Torch_Op<"aten.stft", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$n_fft, + AnyTorchOptionalIntType:$hop_length, + AnyTorchOptionalIntType:$win_length, + AnyTorchOptionalTensorType:$window, + Torch_BoolType:$normalized, + AnyTorchOptionalBoolType:$onesided, + AnyTorchOptionalBoolType:$return_complex + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenStftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 63eac34270db5..a6d05d7cc8b80 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3300,4 +3300,170 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*Torch_BoolType:$antialias*/ cstFalse); return success(); }); + patterns.onOp( + "STFT", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // operands in order ->(signal, frameStep, window, frameLength*) + SmallVector operands; + int64_t onesided; + Torch::ValueTensorType resultType; + + if (binder.tensorOperandsList(operands) || + binder.s64IntegerAttr(onesided, "onesided", 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value signal = operands[0]; + Value frameStep = operands[1]; + auto signalTy = cast(signal.getType()); + auto signalShape = signalTy.getSizes(); + auto resultShape = resultType.getSizes(); + + // There are two possible cases for optional inputs frameLength and + // window, which are that either 4 operands will be passed with window + // being !torch.none, or three operands will be passed, with window + // present and frameLength absent. In the former case, we simply create + // a rectangular window consisting of ones, and in the latter, we set + // frameLength equal to the the inputShape[-2] or windowShape[0] + // depending upon whether window was present or not. Note that it is + // possible that both window and frameLength can be none, which would + // mean that either only two operands were passed, or, in case of three + // operands, window was passed in as none, and frameLength was absent. + Value window = nullptr, frameLength = nullptr; + bool windowIsNone = true, frameLengthIsNone = true; + if (operands.size() == 3) { + window = operands[2]; + windowIsNone = isa(window.getType()); + } + if (operands.size() == 4) { + window = operands[2]; + frameLength = operands[3]; + windowIsNone = isa(window.getType()); + frameLengthIsNone = isa(frameLength.getType()); + } + + ArrayRef windowShape; + if (frameLengthIsNone) { + if (windowIsNone) { + frameLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + signalShape[signalShape.size() - 2])); + } else { + windowShape = + cast(window.getType()).getSizes(); + frameLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + } + } + + Value frameLengthItem; + if (!frameLengthIsNone || windowIsNone) { + frameLengthItem = + getItemOp(binder, rewriter, frameLength); + } else { + frameLengthItem = frameLength; + } + Value frameStepItem = + getItemOp(binder, rewriter, frameStep); + + if (windowIsNone) { + auto onesResultTy = rewriter.getType( + ArrayRef({-1}), signalTy.getDtype()); + + Value none = rewriter.create(binder.getLoc()); + Value sizes = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + SmallVector{frameLengthItem}); + window = rewriter.create( + binder.getLoc(), onesResultTy, sizes, none, none, none, none); + } + + FailureOr complexDtype; + if (signalTy.getDtype().isBF16()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support for bfloat16 type is unimplemented."); + } + if (signalTy.getDtype().isF16()) { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexHalf); + } else if (signalTy.getDtype().isF32()) { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexFloat); + } else { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexDouble); + } + + auto complexSignalTy = rewriter.getType( + ArrayRef({signalShape[0], signalShape[1]}), + complexDtype.value()); + + // The onnx STFT op always passes in a float input, and if the input + // is intended to be complex, its shape will be [batch][length][2], + // where [...][0] is the real component, and [...][1] is the complex + // component. This complex input has to be made torch compatible before + // being passed into torch.stft, so it is necessary to call + // AtenViewAsComplexOp. In case of real input, the shape of the signal + // will be [batch][length][1], and therefore it will have to be squeezed + // at dim=2, before being passed into torch.stft. + if (signalShape[2] == 2) { + signal = rewriter.create( + binder.getLoc(), complexSignalTy, signal); + } else { + Value two = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + auto newSignalTy = signalTy.getWithSizesAndDtype( + ArrayRef({signalShape[0], signalShape[1]}), + signalTy.getDtype()); + signal = rewriter.create( + binder.getLoc(), newSignalTy, signal, two); + } + + // In case the window is not given, we use frameLength + // as the length of the window. + Value windowLen; + if (!windowIsNone) { + windowLen = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + } else { + windowLen = frameLengthItem; + } + + Value falseVal = + rewriter.create(binder.getLoc(), false); + Value trueVal = + rewriter.create(binder.getLoc(), true); + auto stftTy = complexSignalTy.getWithSizesAndDtype( + ArrayRef({resultShape[0], resultShape[2], resultShape[1]}), + complexSignalTy.getDtype()); + + // After torch.stft is called and the result is stored into the value + // stft, there is one thing to note: The resultType for the onnx op + // will have shape [batch][num_frames][length][2], while the shape of + // stft will be [batch][length][num_frames]. Before the value is + // converted to real through torch.view_as_real, we must permute the + // shape of stft to match the shape of resultType. Also, it is + // immaterial whether torch.view_as_real is called after or before the + // permutation; both outputs will be equivalent. + Value stft = rewriter.create( + binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem, + windowLen, window, falseVal, onesided ? trueVal : falseVal, + trueVal); + + auto permuteStftTy = complexSignalTy.getWithSizesAndDtype( + ArrayRef({resultShape[0], resultShape[1], resultShape[2]}), + complexSignalTy.getDtype()); + Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1}); + Value permutedStft = rewriter.create( + binder.getLoc(), permuteStftTy, stft, permuteDims); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, permutedStft); + return success(); + }); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index e9147d5853eca..537d3b6198a46 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10143,6 +10143,125 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.optional) {\n" +" %24 = torch.derefine %none : !torch.none to !torch.optional\n" +" torch.prim.If.yield %24 : !torch.optional\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.derefine %24 : !torch.int to !torch.optional\n" +" torch.prim.If.yield %25 : !torch.optional\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %9 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %24 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" %24 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" %24 = torch.aten.le.int %arg1, %8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = torch.prim.ListConstruct : () -> !torch.list\n" +" %15 = torch.aten.__isnot__ %5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" %24 = torch.prim.unchecked_cast %5 : !torch.optional -> !torch.int\n" +" %25 = torch.aten.append.t %14, %24 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.__is__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.bool\n" +" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %17 -> () {\n" +" %24 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %14, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %24 = torch.aten.append.t %14, %arg1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.floordiv.int %18, %10 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %int1, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.append.t %14, %20 : !torch.list, !torch.int -> !torch.list\n" +" %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %24 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" %24 = torch.aten.append.t %14, %int2 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %14 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" " %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" @@ -11607,6 +11726,143 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int5 = torch.constant.int 5\n" +" %int8 = torch.constant.int 8\n" +" %none = torch.constant.none\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %7 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %7 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %11 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %11 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n" +" %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n" +" } else {\n" +" %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" %15 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n" +" %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n" +" } else {\n" +" %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %15 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %19 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" %19 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 018377e45c169..e77a1978b1019 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1976,6 +1976,35 @@ def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self +@check_shape_function([ + Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an explicit 1-D window. +]) +def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Optional[List[int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> List[int]: + assert len(self) == 1 or len(self) == 2, "Expected input tensor to be of shape (B?,L), where B is an optional batch dimension" + + batch = None if len(self) == 1 else self[0] + length = self[0] if len(self) == 1 else self[1] + hop_length = (n_fft // 4) if hop_length is None else hop_length + assert n_fft > 0 and n_fft <= length, "Expected that 0 < n_fft <= len" + assert hop_length > 0, "Expected hop_length to be greater than 0" + + out: List[int] = [] + if batch is not None: + out.append(batch) # (B?,) + + if onesided is None or onesided == True: + out.append(n_fft//2 + 1) + else: + out.append(n_fft) # (B?,N,) + + # For this operator, center=False by default + out.append(1 + (length - n_fft)//hop_length) #(B?,N,T,) + + if return_complex is not None and bool(return_complex) == False: + out.append(2) # a length-2 dimension of real and imaginary components. This gives output shape (B?,N,T,C?). + + return out + class DummyClassType: def __init__(self): pass @@ -3307,6 +3336,37 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = else: assert False, "Unsupported dtype" +@check_dtype_function([ + Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32 + Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64 + Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=True), # output dtype = torch.complex64 + Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=False), # output dtype = torch.float32 +]) +def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window_rank_dtype: Optional[Tuple[int, int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype) and return_complex is not None and return_complex: + return self_dtype + elif is_complex_dtype(self_dtype) and return_complex is not None and return_complex != True: + if self_dtype == torch.complex32: + return torch.float16 + elif self_dtype == torch.complex64: + return torch.float32 + elif self_dtype == torch.complex128: + return torch.float64 + elif is_float_dtype(self_dtype) and return_complex is not None and return_complex: + if self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_float_dtype(self_dtype) and return_complex is not None and return_complex != True: + return self_dtype + elif is_integer_dtype(self_dtype): + return torch.complex64 + + assert False, "Unsupported dtype" + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 9cf8b26029648..b21362f7c8ef6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -921,6 +921,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)", has_verifier=True, ) + emit( + "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)" + ) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 8e37e1d832021..445d54c8697fc 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2583,3 +2583,52 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: ! %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> return %0 : !torch.vtensor<[1,1,4,6],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_stft +func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list + // CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32> + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_stft_with_window +func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16 + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16 + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} From e346c911f7f2f21d59f0ed4fb01059aba540d7a9 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 25 Jun 2024 11:02:45 -0500 Subject: [PATCH 07/30] [ONNX] Add basic support for RoiAlign (#3493) This adds an onnx->torch conversion for onnx.RoiAlign into torchvision.roi_align or torchvision.roi_pool, and adds those two torchvision ops to torch-mlir. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 57 +++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 98 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 29 ++++++ .../build_tools/abstract_interp_lib_gen.py | 15 +++ .../build_tools/torch_ods_gen.py | 9 ++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 31 ++++++ 6 files changed, 239 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c351d845c2f8d..bab7131f72382 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16660,3 +16660,60 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ }]; } +def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$rois, + Torch_FloatType:$spatial_scale, + Torch_IntType:$pooled_height, + Torch_IntType:$pooled_width, + Torch_IntType:$sampling_ratio, + Torch_BoolType:$aligned + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionRoiAlignOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void TorchvisionRoiAlignOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$rois, + Torch_FloatType:$spatial_scale, + Torch_IntType:$pooled_height, + Torch_IntType:$pooled_width + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionRoiPoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void TorchvisionRoiPoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index a6d05d7cc8b80..58d8397ee67c0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2953,6 +2953,104 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*Torch_BoolType:$antialias*/ cstFalse); return success(); }); + patterns.onOp( + "RoiAlign", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // operands = input, rois, batch_indices + SmallVector operands; + std::string coordTfMode, mode; + int64_t outHInt, outWInt, samplingRatioInt; + float spatialScaleFloat; + Torch::ValueTensorType resultType; + if (binder.tensorOperands(operands, 3) || + binder.customOpNameStringAttr( + coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.customOpNameStringAttr(mode, "mode", "avg") || + binder.s64IntegerAttr(outHInt, "output_height", 1) || + binder.s64IntegerAttr(outWInt, "output_width", 1) || + binder.s64IntegerAttr(samplingRatioInt, "sampling_ratio", 0) || + binder.f32FloatAttr(spatialScaleFloat, "spatial_scale", 1.0f) || + binder.tensorResultType(resultType)) + return failure(); + Value input = operands[0]; + Value rois = operands[1]; + Value batchIndices = operands[2]; + + // the torchvision roi_pool op does not support these features: + if (mode == "max" && + (coordTfMode != "half_pixel" || samplingRatioInt != 0)) + return rewriter.notifyMatchFailure( + binder.op, "unsupported: roi max pooling without default " + "coordTfMode and sampling_ratio"); + + Location loc = binder.getLoc(); + // concatenate the batchIndices to the rois to get rois as a num_roisx5 + // tensor. The batchIndices tensor is an int64 tensor, and needs to be + // converted to float before concatenation. + auto roisType = dyn_cast(rois.getType()); + if (!roisType || !roisType.hasSizes()) + return failure(); + Value cstDim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + FailureOr unsqueezeIndices = + Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim); + if (failed(unsqueezeIndices)) + return failure(); + batchIndices = unsqueezeIndices.value(); + auto batchIndicesType = + cast(batchIndices.getType()); + Value dTypeInt = + Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype()); + Value none = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value newBatchIndices = rewriter.create( + loc, + batchIndicesType.getWithSizesAndDtype( + batchIndicesType.getOptionalSizes(), + roisType.getOptionalDtype()), + batchIndices, dTypeInt, cstFalse, cstFalse, none); + SmallVector roiSizes(roisType.getSizes()); + roiSizes.back() = 5; + auto catType = rewriter.getType( + roiSizes, roisType.getDtype()); + Type listElemType = + roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + binder.op->getLoc(), listType, ValueRange{newBatchIndices, rois}); + Value newRois = + rewriter.create(loc, catType, tensorList, cstDim); + + // make constants from attributes + Value cstSpatialScale = rewriter.create( + loc, rewriter.getF64FloatAttr(spatialScaleFloat)); + Value pooledHeight = rewriter.create( + loc, rewriter.getI64IntegerAttr(outHInt)); + Value pooledWidth = rewriter.create( + loc, rewriter.getI64IntegerAttr(outWInt)); + // this is for consistency with the default pytorch sampling ratio value + samplingRatioInt = (samplingRatioInt == 0) ? -1 : samplingRatioInt; + Value samplingRatio = rewriter.create( + loc, rewriter.getI64IntegerAttr(samplingRatioInt)); + bool aligned = coordTfMode == "half_pixel"; + Value cstAligned = rewriter.create(loc, aligned); + + if (mode == "avg") { + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, newRois, cstSpatialScale, + pooledHeight, pooledWidth, samplingRatio, cstAligned); + return success(); + } + // mode == "max" + auto indicesType = resultType.getWithSizesAndDtype( + resultType.getOptionalSizes(), batchIndicesType.getDtype()); + auto roiPool = rewriter.create( + loc, TypeRange{resultType, indicesType}, input, newRois, + cstSpatialScale, pooledHeight, pooledWidth); + rewriter.replaceOp(binder.op, roiPool.getResult(0)); + return success(); + }); patterns.onOp( "SpaceToDepth", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 537d3b6198a46..69d48fa3c0d5d 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6256,6 +6256,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.roi_align\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_align\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.roi_pool\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.prim.TupleConstruct %2, %2 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %3 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_pool\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n" " %true = torch.constant.bool true\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index e77a1978b1019..97fe12255a800 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -8,6 +8,7 @@ import os import torch +import torchvision from torch import device import torch.jit._shape_functions as upstream_shape_functions @@ -85,6 +86,20 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) + +def torchvision〇roi_align〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> List[int]: + return [rois[0], input[1], pooled_height, pooled_width] + +def torchvision〇roi_align〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> int: + return input_rank_dtype[1] + +def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[List[int], List[int]]: + output = [rois[0], input[1], pooled_height, pooled_width] + return (output, output) + +def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]: + return (input_rank_dtype[1], torch.int64) + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`. diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index b21362f7c8ef6..401e7bef20c16 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1155,6 +1155,13 @@ def emit_with_mutating_variants(key, **kwargs): traits=["HasValueSemantics"], ) + emit( + "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)" + ) + emit( + "torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)" + ) + def dump_registered_ops(outfile: TextIO, registry: Registry): for _, v in sorted(registry.by_unique_key.items()): @@ -1173,6 +1180,8 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args: argparse.Namespace): _maybe_import_op_extensions(args) + import torchvision + registry = Registry.load() if args.debug_registry_dump: with open(args.debug_registry_dump, "w") as debug_registry_dump: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 445d54c8697fc..d611823f9052d 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2207,6 +2207,37 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve // ----- +// CHECK-LABEL: @test_roialign_avg + func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[Dim:.*]] = torch.constant.int 1 + // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]] + // CHECK: %[[cst6:.*]] = torch.constant.int 6 + // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]] + // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1 + // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]] + // CHECK: %[[Align:.*]] = torch.torchvision.roi_align %arg0, %[[Cat]] + %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "output_half_pixel", torch.onnx.mode = "avg", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> + return %0 : !torch.vtensor<[30,2,5,5],f32> + } + +// ----- + +// CHECK-LABEL: @test_roialign_max + func.func @test_roialign_max(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[Dim:.*]] = torch.constant.int 1 + // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]] + // CHECK: %[[cst6:.*]] = torch.constant.int 6 + // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]] + // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1 + // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]] + // CHECK: %[[Pool:.*]], %[[Indices:.*]] = torch.torchvision.roi_pool %arg0, %[[Cat]] + // CHECK: return %[[Pool]] + %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "max", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> + return %0 : !torch.vtensor<[30,2,5,5],f32> + } + +// ----- + // CHECK-LABEL: @test_spacetodepth_example func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 From 368fabf0c1a691fe4bdac1b6d6c1011c45eccf21 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 25 Jun 2024 12:16:51 -0500 Subject: [PATCH 08/30] [ONNX] Basic Support for DeformConv (#3469) This adds a torchvision op to torch-mlir and a path from onnx.DeformConv to torchvision.deform_conv2d. I'm not implementing the torch->linalg lowering for the torchvision op yet, but posting this PR to get feedback on some of the choices being made here and to flesh out the onnx frontend a bit. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 36 +++++ .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 135 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 16 +++ projects/pt1/e2e_testing/xfail_sets.py | 24 +++- .../build_tools/abstract_interp_lib_gen.py | 10 +- .../build_tools/torch_ods_gen.py | 8 ++ .../configs/onnx_backend.py | 8 +- .../torch_mlir_e2e_test/test_suite/conv.py | 87 +++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 13 ++ 9 files changed, 328 insertions(+), 9 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index bab7131f72382..4b2ba6defa027 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16660,6 +16660,42 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ }]; } +def Torch_TorchvisionDeformConv2dOp : Torch_Op<"torchvision.deform_conv2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$offset, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$bias, + Torch_IntType:$stride_h, + Torch_IntType:$stride_w, + Torch_IntType:$pad_h, + Torch_IntType:$pad_w, + Torch_IntType:$dilation_h, + Torch_IntType:$dilation_w, + Torch_IntType:$groups, + Torch_IntType:$offset_groups, + Torch_BoolType:$use_mask + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionDeformConv2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 14, 1); + } + void TorchvisionDeformConv2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 14, 1); + } + }]; +} + def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 6932908c05c6b..c89452ad6cb35 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1837,6 +1837,141 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, transposedInput, reshapeSizesList); return success(); }); + patterns.onOp( + "DeformConv", 19, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + auto loc = binder.getLoc(); + + // get operands + llvm::SmallVector operands; + Torch::ValueTensorType resultType; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType)) + return failure(); + if (operands.size() < 3 || operands.size() > 5) + return failure(); + auto inputType = + dyn_cast(operands[0].getType()); + if (!inputType || !inputType.hasSizes() || + inputType.getSizes().size() != 4) + return rewriter.notifyMatchFailure( + binder.op, "Unsupported: DeformConv with input rank != 4"); + unsigned rank = inputType.getSizes().size(); + auto weightType = + dyn_cast(operands[1].getType()); + if (!weightType || !weightType.hasSizes()) + return failure(); + auto offsetType = + dyn_cast(operands[2].getType()); + if (!offsetType || !offsetType.hasSizes()) + return failure(); + + // get attributes + SmallVector dilations, kernelShape, pads, strides; + SmallVector defaultDilations(rank - 2, 0); + SmallVector defaultPads(2 * (rank - 2), 0); + SmallVector defaultStrides(rank - 2, 1); + int64_t group, offsetGroup; + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations) || + binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}) || + binder.s64IntegerArrayAttr(pads, "pads", defaultPads) || + binder.s64IntegerArrayAttr(strides, "strides", defaultStrides) || + binder.s64IntegerAttr(group, "group", 1) || + binder.s64IntegerAttr(offsetGroup, "offset_group", 1)) + return failure(); + + for (unsigned i = 0; i < rank - 2; i++) { + if (pads[i] != pads[rank + i - 2]) + return rewriter.notifyMatchFailure( + binder.op, "unsupported: asymmetric padding"); + } + + // Identify and assign names to operands + Value input, weight, offset, bias, mask; + bool useMask = false; + input = operands[0]; + weight = operands[1]; + offset = operands[2]; + if (operands.size() == 4) { + auto unknownOpdRank = Torch::getTensorRank(operands[3]); + if (!unknownOpdRank) + return failure(); + if (*unknownOpdRank == 1) + bias = operands[3]; + else if (*unknownOpdRank == rank) { + mask = operands[3]; + useMask = true; + } else + llvm_unreachable("onnx.DeformConv: optional 4th operand of " + "unexpected rank encountered"); + } + if (operands.size() == 5) { + bias = operands[3]; + mask = operands[4]; + useMask = true; + } + + // assign default operand values if necessary + ArrayRef weightSizes = weightType.getSizes(); + ArrayRef offsetSizes = offsetType.getSizes(); + if (!bias) { + int64_t outputChannels = weightSizes[0]; + SmallVector biasShape(1, outputChannels); + Value biasShapeList = mlir::torch::onnx_c::createConstantIntList( + binder, rewriter, biasShape); + Value cstZero = Torch::getConstantWithGivenDtypeAndValue( + rewriter, loc, 0.0f, inputType.getDtype()); + bias = + Torch::createInitTensor(rewriter, loc, + rewriter.getType( + biasShape, inputType.getDtype()), + cstZero, biasShapeList); + } + if (!mask) { + int64_t batchSize = inputType.getSizes()[0]; + int64_t kernelHeight = weightSizes[2]; + int64_t kernelWidth = weightSizes[3]; + int64_t outputHeight = offsetSizes[2]; + int64_t outputWidth = offsetSizes[3]; + int64_t maskDimOne = offsetGroup * kernelHeight * kernelWidth; + SmallVector maskShape( + {batchSize, maskDimOne, outputHeight, outputWidth}); + Value cstOne = Torch::getConstantWithGivenDtypeAndValue( + rewriter, loc, 1.0f, inputType.getDtype()); + Value maskShapeList = mlir::torch::onnx_c::createConstantIntList( + binder, rewriter, maskShape); + mask = + Torch::createInitTensor(rewriter, loc, + rewriter.getType( + maskShape, inputType.getDtype()), + cstOne, maskShapeList); + } + + // get attributes as constant values + SmallVector dilationValues, padValues, strideValues; + for (auto i : dilations) + dilationValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + for (auto i : pads) + padValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + for (auto i : strides) + strideValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + Value groupValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(group)); + Value offsetGroupValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(offsetGroup)); + Value useMaskValue = rewriter.create( + loc, rewriter.getBoolAttr(useMask)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, offset, mask, bias, + strideValues[0], strideValues[1], padValues[0], padValues[1], + dilationValues[0], dilationValues[1], groupValue, offsetGroupValue, + useMaskValue); + return success(); + }); patterns.onOp( "DequantizeLinear", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 69d48fa3c0d5d..e94d3bd7c9df6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9492,6 +9492,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.deform_conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg2, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.deform_conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.tuple, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.int {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fb997435faf71..35a34e2b10688 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -29,6 +29,9 @@ "InterpolateDynamicModule_scales_recompute_bilinear", "ElementwiseFloatTensorGtIntTensorModule_basic", "AtenIntMM_basic", + # unimplemented lowering torch -> linalg for torchvision.deform_conv2d + # this is added to check the torch.onnx.export -> import_onnx -> torch path + "DeformConv2D_basic", } LINALG_CRASHING_SET = { @@ -383,6 +386,7 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", @@ -554,6 +558,7 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "DeformConv2D_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -2357,19 +2362,12 @@ "DivIntModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", - "ElementwiseAndScalarModule_basic", - "ElementwiseAndScalarStaticShapeModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic", - "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndScalarInt32Module_basic", - "ElementwiseBitwiseAndScalarInt64Module_basic", - "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrModule_basic", @@ -2710,6 +2708,8 @@ "IndexPutHackedTwin3DIntNonAccumulateModule_basic", # RuntimeError: unsupported input type: Device "PrimsIotaModule_basic", + # unimplemented torchvision.deform_conv2d torch->linalg + "DeformConv2D_basic", # Error: 'aten::renorm' to ONNX opset version 17 is not supported. "RenormModuleFloat16_basic", "RenormModuleFloat32NegativeDim_basic", @@ -2759,6 +2759,14 @@ "ElementwiseBitwiseLeftShiftInt32Module_basic", "ElementwiseBitwiseLeftShiftInt64Module_basic", "ElementwiseBitwiseLeftShiftInt8Module_basic", + # bitwise and support has been added in torch nightly + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", } if torch_version_for_comparison() < version.parse("2.4.0.dev"): @@ -2930,6 +2938,7 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "DeformConv2D_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -3724,6 +3733,7 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "DeformConv2D_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 97fe12255a800..1f70a42ce8eeb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -8,7 +8,6 @@ import os import torch -import torchvision from torch import device import torch.jit._shape_functions as upstream_shape_functions @@ -1639,6 +1638,12 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: assert False, "Unsupported dtype" +def torchvision〇deform_conv2d〡shape(input: List[int], weight: List[int], offset: List[int], mask: List[int], bias: List[int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> List[int]: + return [input[0], weight[0], offset[2], offset[3]] + +def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], offset_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> int: + return input_rank_dtype[1] + def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) @@ -5117,6 +5122,9 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args): _maybe_import_op_extensions(args) + # importing torchvision will register torchvision ops with the JITOperatorRegistry + import torchvision + asm = generate_library(globals()) # We're about to put quotes around the string, so escape the `"` characters. asm = asm.replace("\"", "\\\"") diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 401e7bef20c16..7c3f79ef44297 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1155,6 +1155,13 @@ def emit_with_mutating_variants(key, **kwargs): traits=["HasValueSemantics"], ) + # ========================================================================== + # `torchvision::` namespace. + # ========================================================================== + + emit( + "torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)" + ) emit( "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)" ) @@ -1180,6 +1187,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args: argparse.Namespace): _maybe_import_op_extensions(args) + # importing torchvision will register torchvision ops with the JITOperatorRegistry import torchvision registry = Registry.load() diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index fb9b2712d3197..fc0d488b4787a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -9,6 +9,7 @@ import io import onnx import torch +from torch.onnx._constants import ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET as max_opset_ver import torch_mlir from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -78,7 +79,12 @@ def convert_onnx(model, inputs): examples = tuple(examples) torch.onnx.export( - model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors + model, + examples, + buffer, + input_names=input_names, + dynamic_axes=dynamic_tensors, + opset_version=max_opset_ver, ) buffer = buffer.getvalue() return import_onnx(buffer) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index af8bea091d08f..2e00e2079cb35 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1256,3 +1256,90 @@ def ConvTranspose2DQInt8_basic(module, tu: TestUtils): tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), torch.rand(Cout), ) + + +# torchvision.deform_conv2d + +import torchvision + +# This section defines a torch->onnx path for this torchvision op so we can test the onnx paths e2e. + +# Create symbolic function +from torch.onnx.symbolic_helper import parse_args, _get_tensor_sizes + + +@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b") +def symbolic_deform_conv2d_forward( + g, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask, +): + args = [input, weight, offset, bias] + if use_mask: + args.append(mask) + weight_size = _get_tensor_sizes(weight) + kwargs = { + "dilations_i": [dilation_h, dilation_w], + "group_i": groups, + "kernel_shape_i": weight_size[2:], + "offset_group_i": offset_groups, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": [pad_h, pad_w, pad_h, pad_w], + "strides_i": [stride_h, stride_w], + } + return g.op("DeformConv", *args, **kwargs) + + +# Register symbolic function +from torch.onnx import register_custom_op_symbolic + +register_custom_op_symbolic( + "torchvision::deform_conv2d", symbolic_deform_conv2d_forward, 19 +) + +N = 1 +Cin = 1 +Hin = 7 +Win = 6 +Cout = 1 +Hker = 2 +Wker = 2 +offset_groups = 1 +Hout = 6 +Wout = 5 +offset_dim1 = 2 * offset_groups * Hker * Wker + + +class DeformableConvModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([N, Cin, Hin, Win], torch.float32, True), + ([N, offset_dim1, Hout, Wout], torch.float32, True), + ([Cout, Cin, Hker, Wker], torch.float32, True), + ] + ) + def forward(self, input, offset, weight): + return torchvision.ops.deform_conv2d(input, offset, weight) + + +@register_test_case(module_factory=lambda: DeformableConvModule()) +def DeformConv2D_basic(module, tu: TestUtils): + input = tu.rand(N, Cin, Hin, Win) + offset = tu.rand(N, offset_dim1, Hout, Wout) + weight = tu.rand(Cout, Cin, Hker, Wker) + module.forward(input, offset, weight) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 74793852de4a8..4b03fcceeec18 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -735,6 +735,19 @@ func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4 // ----- +// CHECK-LABEL: @test_deform_conv +func.func @test_deform_conv(%arg0: !torch.vtensor<[1,1,7,6],f32>, %arg1: !torch.vtensor<[1,8,6,5],f32>, %arg2: !torch.vtensor<[1,1,2,2],f32>, %arg3: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} { + // CHECK: %[[cstOne:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[mask:.*]] = torch.aten.full %[[sizeList:.*]], %[[cstOne]] + // CHECK-SAME: -> !torch.vtensor<[1,4,6,5],f32> + // CHECK: torch.torchvision.deform_conv2d %arg0, %arg2, %arg1, %[[mask]], %arg3 + // CHECK-SAME: : !torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1,4,6,5],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[1,1,6,5],f32> + %1 = torch.operator "onnx.DeformConv"(%arg0, %arg2, %arg1, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.offset_group = 1 : si64, torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> + return %1 : !torch.vtensor<[1,1,6,5],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_si8 func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> From d2bc70f18855e672f91942b19259a5938d6d3cf4 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:34:19 -0500 Subject: [PATCH 09/30] [TorchToLinalg][ONNX] Add Basic Determinant Support (#3481) This adds support for a few ops: - torch.linalg_det - torch._linalg_det (if the LU and pivot returns are unused) - onnx.Det An scf loop is used, since the row reduction algorithm applied here has some loop-carried dependencies. The current support being added here is very basic, and only works if no permutations are required during row reduction, and assumes the matrices are non-singular. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 48 ++++ .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 10 + .../TorchToLinalg/TorchToLinalg.cpp | 4 +- .../TorchToLinalg/Uncategorized.cpp | 215 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 76 +++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 23 ++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 10 + .../build_tools/abstract_interp_lib_gen.py | 19 ++ .../build_tools/torch_ods_gen.py | 2 + .../test_suite/__init__.py | 1 + .../test_suite/linalg_algorithms.py | 51 +++++ 12 files changed, 459 insertions(+), 1 deletion(-) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4b2ba6defa027..be5bc56d7fe78 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8586,6 +8586,54 @@ def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [ }]; } +def Torch_AtenLinalgDetOp : Torch_Op<"aten.linalg_det", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_det : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgDetOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLinalgDetOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_Aten_LinalgDetOp : Torch_Op<"aten._linalg_det", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A + ); + let results = (outs + AnyTorchOptionalTensorType:$result, + AnyTorchOptionalTensorType:$LU, + AnyTorchOptionalTensorType:$pivots + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_LinalgDetOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 3); + } + void Aten_LinalgDetOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 3); + } + }]; +} + def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index c89452ad6cb35..446298e89b336 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1972,6 +1972,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( useMaskValue); return success(); }); + patterns.onOp( + "Det", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + if (binder.tensorOperand(input) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp(binder.op, + resultType, input); + return success(); + }); patterns.onOp( "DequantizeLinear", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 7f57744b4af5d..01b1d4b973b6a 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" @@ -42,6 +43,7 @@ class ConvertTorchToLinalg registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } @@ -51,7 +53,7 @@ class ConvertTorchToLinalg ConversionTarget target(*context); target.addLegalDialect< linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect, - math::MathDialect, sparse_tensor::SparseTensorDialect, + math::MathDialect, scf::SCFDialect, sparse_tensor::SparseTensorDialect, tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>(); target.addLegalOp(); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 1330174699a53..5e5f860652019 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -2952,6 +2953,218 @@ class ConvertInterpolateOp } }; } // namespace + +namespace { +// This pattern row reduces a matrix, then returns the product of it's diagonal +// elements +class ConvertAtenLinalgDetOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenLinalgDetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + Value input = adaptor.getA(); + auto inputType = cast(input.getType()); + unsigned inputRank = inputType.getRank(); + auto elemTy = inputType.getElementType(); + bool isBatched = (inputRank == 3); + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstZeroF = getConstant(rewriter, loc, 0, elemTy); + // get some shapes + SmallVector inputShape(inputType.getShape()); + SmallVector sliceShape(inputShape); + sliceShape.pop_back(); + SmallVector diagShape({isBatched ? inputType.getShape()[0] : 1}); + auto sliceTy = RankedTensorType::get(sliceShape, elemTy); + auto diagTy = RankedTensorType::get(diagShape, elemTy); + // get some sizes + SmallVector inputSizes = getTensorSizes(rewriter, loc, input); + Value chDim = isBatched ? inputSizes[0] : cstOne; + Value matDim = inputSizes[inputRank - 1]; + Value matDimMinusOne = rewriter.create(loc, matDim, cstOne); + ArrayRef sliceSizes(inputSizes.begin(), inputSizes.end() - 1); + // initialize a tensor to store the diagonal elements found during row + // reduction + Value initDiags = rewriter.create( + loc, getAsOpFoldResult(sliceSizes), elemTy); + // loop over each pivot row in A. Get the diagonal, then reduce the + // subdiagonal Don't perform the loop on the last row since no further + // reduction is needed. + auto rowReductionLoop = rewriter.create( + loc, /*start=*/cstZero, /*end=*/matDimMinusOne, /*step=*/cstOne, + /*yeild_to=*/ValueRange{input, initDiags}, /*body_lambda=*/ + [&](OpBuilder &b, Location loc, Value row, ValueRange vals) { + // extract row i from input Tensor of shape CxNxN or shape + // NxN. + OpFoldResult cstOneFold = getAsOpFoldResult(cstOne); + OpFoldResult cstZeroFold = getAsOpFoldResult(cstZero); + SmallVector offsets(inputRank, cstZeroFold); + offsets[inputRank - 2] = row; + SmallVector strides(inputRank, cstOneFold); + auto sizes = getAsOpFoldResult(inputSizes); + sizes[inputRank - 2] = cstOneFold; + // offsets = [0, row, 0], sizes = [C, 1, N] -> pivot row + Value pivot = b.create( + loc, sliceTy, vals[0], offsets, sizes, strides); + // extract diagonal elements and insert them into vals[1] + offsets.back() = row; + sizes.back() = cstOneFold; + // offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row) + Value diag = b.create( + loc, diagTy, vals[0], offsets, sizes, strides); + SmallVector diagOffsets(inputRank - 1, cstZeroFold); + diagOffsets.back() = row; + SmallVector diagStrides(inputRank - 1, cstOneFold); + SmallVector diagSizes = getAsOpFoldResult(sliceSizes); + diagSizes.back() = cstOneFold; + // offsets = [0, row], sizes = [C, 1] insert to [C,N] + Value updatedDiags = b.create( + loc, diag, vals[1], diagOffsets, diagSizes, diagStrides); + // the subpivot matrix column size, as a Value, is matDim - row - + // cstOne. This can't be statically converted to an int64_t, since row + // is the loop index, so this is left as a dynamic dim. + SmallVector subPivotShape(inputType.getShape()); + subPivotShape[inputRank - 2] = ShapedType::kDynamic; + ArrayRef subDiagShape(subPivotShape.begin(), + subPivotShape.end() - 1); + auto subPivotTy = RankedTensorType::get(subPivotShape, elemTy); + auto subDiagTy = RankedTensorType::get(subDiagShape, elemTy); + Value rowPlusOne = b.create(loc, row, cstOne); + offsets[inputRank - 2] = getAsOpFoldResult(rowPlusOne); + sizes[inputRank - 2] = getAsOpFoldResult( + b.create(loc, matDim, rowPlusOne)); + // offsets = [0, row + 1, row], sizes = [C, N - row - 1, 1] -> A_j,row + // with j > row + Value subDiag = b.create( + loc, subDiagTy, vals[0], offsets, sizes, strides); + offsets.back() = cstZeroFold; + sizes.back() = getAsOpFoldResult(matDim); + // offsets = [0, row + 1, 0], sizes = [C, N - row - 1, N] -> elements + // below pivot row + Value subPivot = b.create( + loc, subPivotTy, vals[0], offsets, sizes, strides); + Value initResult = b.create(loc, sizes, elemTy); + // write a generic op to perform subpivot = subpivot - + // (subdiag/diag)*pivot + // d0 = batches, d1 = row, d2 = column -> pivot(d0,d2), diag(d0), + // subPivot(d0,d1,d2), subDiag(d0, d1); output(d0,d1,d2) + SmallVector allDims; + for (unsigned i = 0; i < inputRank; i++) + allDims.push_back(b.getAffineDimExpr(i)); + SmallVector rowIterator(1, allDims[0]); + SmallVector colIterator; + SmallVector batchIterator; + if (isBatched) { + rowIterator.push_back(allDims[1]); + colIterator.push_back(allDims[0]); + colIterator.push_back(allDims[2]); + batchIterator.push_back(allDims[0]); + } else { + colIterator.push_back(allDims[1]); + batchIterator.push_back(getAffineConstantExpr(0, context)); + } + SmallVector indexingMaps; + indexingMaps.push_back( + AffineMap::get(inputRank, 0, colIterator, context)); + indexingMaps.push_back( + AffineMap::get(inputRank, 0, batchIterator, context)); + indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank)); + indexingMaps.push_back( + AffineMap::get(inputRank, 0, rowIterator, context)); + indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank)); + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value reducedSubPivot = + b.create( + loc, subPivotTy, ValueRange{pivot, diag, subPivot, subDiag}, + initResult, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // for d0 in batches, d1 in subpivotrows, d2 in columns + // let i represent the pivot row index (scf loop index) + Value pivotd0d2 = args[0]; + Value diagd0 = args[1]; + Value subPivotd0d1d2 = args[2]; + Value subDiagd0d1 = args[3]; + // coeff = A_d1,i / A_i,i + Value coeff = + b.create(loc, subDiagd0d1, diagd0); + auto cmp = b.create( + loc, arith::CmpFPredicate::ONE, diagd0, cstZeroF); + b.create( + loc, cmp, + b.getStringAttr( + "unimplemented: determinants requiring " + "permutations and singular matrices")); + // coeff*A_i,d2 + Value scaledPivotValue = + b.create(loc, coeff, pivotd0d2); + // result = A_d1,d2 - (A_d1,i/A_i,i)*A_i,d2 + // so that when d2 = i, A_d1,i - (A_d1,i/A_i,i) * A_i,i = 0 + Value result = b.create(loc, subPivotd0d1d2, + scaledPivotValue); + b.create(loc, result); + }) + .getResult(0); + Value rowReductionResult = b.create( + loc, reducedSubPivot, vals[0], offsets, sizes, strides); + b.create(loc, + ValueRange{rowReductionResult, updatedDiags}); + }); + Value allDiagsExceptLast = rowReductionLoop.getResult(1); + SmallVector offsets(inputRank, + getAsOpFoldResult(matDimMinusOne)); + SmallVector strides(inputRank, getAsOpFoldResult(cstOne)); + SmallVector sizes(inputRank, getAsOpFoldResult(cstOne)); + sizes[0] = getAsOpFoldResult(chDim); + if (isBatched) + offsets[0] = getAsOpFoldResult(cstZero); + Value lastDiag = rewriter.create( + loc, diagTy, rowReductionLoop.getResult(0), offsets, sizes, strides); + offsets.pop_back(); + strides.pop_back(); + sizes.pop_back(); + Value allDiags = rewriter.create( + loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides); + // linalg generic to do reduce prod for allDiags along back dim. + // the result of that generic will be the determinant + SmallVector indexingMaps; + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(inputRank - 1)); + AffineExpr resultExpr = isBatched ? rewriter.getAffineDimExpr(0) + : getAffineConstantExpr(0, context); + indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr)); + SmallVector iteratorTypes( + inputRank - 1, utils::IteratorType::parallel); + Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy, + getConstant(rewriter, loc, 1.0, elemTy)); + Value determinant = + rewriter + .create( + loc, initDet.getType(), ValueRange{allDiags}, initDet, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value prod = b.create(loc, args[0], args[1]); + b.create(loc, prod); + }) + .getResult(0); + Type newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (isBatched) { + rewriter.replaceOpWithNewOp(op, newResultType, + determinant); + return success(); + } + Value detVal = rewriter.create( + loc, determinant, SmallVector(1, cstZero)); + rewriter.replaceOpWithNewOp(op, newResultType, + ValueRange{detVal}); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -3009,4 +3222,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index e94d3bd7c9df6..6974636c0e86c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6485,6 +6485,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0: !torch.list) -> !torch.list {\n" +" %int-2 = torch.constant.int -2\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.list) {\n" +" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.If.yield %9 : !torch.list\n" +" } else {\n" +" %9 = torch.derefine %arg0 : !torch.list to !torch.any\n" +" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list\n" +" torch.prim.If.yield %10 : !torch.list\n" +" }\n" +" return %8 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._linalg_det\"(%arg0: !torch.list) -> !torch.tuple, list, list> {\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int-1 = torch.constant.int -1\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = torch.aten.slice.t %arg0, %none, %int-1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %arg0, %1 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %2 : !torch.tuple, list, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._linalg_det\"(%arg0: !torch.tuple) -> !torch.tuple {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %1 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10986,6 +11048,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_det\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 04f505bea6793..7c2c29a6d720b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2619,6 +2619,28 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { }; } // namespace +namespace { + +class DecomposeAten_LinalgDetOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_LinalgDetOp op, + PatternRewriter &rewriter) const override { + SmallVector results = op.getResults(); + if (!results[1].use_empty() || !results[2].use_empty()) + return rewriter.notifyMatchFailure( + op, "unsupported: _linalg_det results: LU and pivot"); + Location loc = op.getLoc(); + Value input = op.getA(); + Value determinant = rewriter.create( + loc, results[0].getType(), input); + rewriter.replaceAllUsesWith(results[0], determinant); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. // @@ -8701,6 +8723,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns); // More specific conv ops diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 0006a97f44d2d..21e2abb2474e9 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -404,6 +404,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 35a34e2b10688..a0d7616a6a957 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -559,6 +559,9 @@ "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", "DeformConv2D_basic", + "DeterminantBatchedModule_F32", + "DeterminantDynamicModule_F32", + "DeterminantModule_F32", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -2939,6 +2942,9 @@ "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", "DeformConv2D_basic", + "DeterminantBatchedModule_F32", + "DeterminantDynamicModule_F32", + "DeterminantModule_F32", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -3734,6 +3740,10 @@ "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", "DeformConv2D_basic", + "DeterminantModule_F32", + "DeterminantBatchedModule_F32", + "DeterminantDynamicModule_F32", + "DeterminantModule_F32", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1f70a42ce8eeb..0b356cc3412cf 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -223,6 +223,19 @@ def aten〇sign〡shape(self: List[int]) -> List[int]: def aten〇sgn〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇linalg_det〡shape(A: List[int]) -> List[int]: + assert len(A) == 2 or len(A) == 3 + assert A[-1] == A[-2] + if len(A) == 3: + return A[:1] + return upstream_shape_functions.zero_dim_tensor(A) + +def aten〇_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List[int]]: + return (aten〇linalg_det〡shape(A), A, A[:-1]) + +def aten〇_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]: + return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1]) + def aten〇detach〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2630,6 +2643,12 @@ def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes()})) +def aten〇linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = A_rank_dtype + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int: input_rank, input_dtype = input_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 7c3f79ef44297..90d3e10546849 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -699,6 +699,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") + emit("aten::linalg_det : (Tensor) -> (Tensor)") + emit("aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 46d2909eb8ab0..03f8bc193be16 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -43,6 +43,7 @@ def register_all_tests(): from . import slice_like from . import nll_loss from . import index_select + from . import linalg_algorithms from . import arange from . import constant_alloc from . import threshold diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py new file mode 100644 index 0000000000000..0bb620591c407 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py @@ -0,0 +1,51 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + + +class DeterminantModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.det(A) + + +@register_test_case(module_factory=lambda: DeterminantModule()) +def DeterminantModule_F32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.float32) + module.forward(A) + + +class DeterminantBatchedModule(torch.nn.Module): + @export + @annotate_args([None, [(3, 4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.det(A) + + +@register_test_case(module_factory=lambda: DeterminantBatchedModule()) +def DeterminantBatchedModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) + + +class DeterminantDynamicModule(torch.nn.Module): + @export + @annotate_args([None, [(-1, -1, -1), torch.float32, True]]) + def forward(self, A): + return torch.linalg.det(A) + + +@register_test_case(module_factory=lambda: DeterminantBatchedModule()) +def DeterminantDynamicModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) From e29191bd08753e342e7ada78612ba5cad483a6e0 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Wed, 26 Jun 2024 09:59:49 +0100 Subject: [PATCH 10/30] [LINALG] Broadcast `values` to shape of slize in `index_put` (#3487) The `index_put` operation, `input[indices] = values`, allows for the values to be any shape that is broadcastable to the slice `input[indices]`. This commit adds broadcasting support to the Linalg lowering of `IndexPutHackedTwinOp`. Fixes: #3465 --- .../TorchToTMTensor/TorchToTMTensor.cpp | 63 +++++++++++++++---- projects/pt1/e2e_testing/xfail_sets.py | 3 +- .../torch_mlir_e2e_test/test_suite/scatter.py | 33 ++++++++++ 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 9d0a764c18522..b6bd3b8b6a362 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -541,19 +541,9 @@ class ConvertAtenBincountOp : public OpConversionPattern { namespace { -Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, - OpBuilder b) { - llvm::SmallVector indices(indicesRef); - // Declare commonly used constants up front: - Value torchCstZero = - b.create(loc, b.getI64IntegerAttr(0)); - Value torchCstOne = - b.create(loc, b.getI64IntegerAttr(1)); - Value torchCstNegOne = - b.create(loc, b.getI64IntegerAttr(-1)); - - // Determine the broadcast sizes and materialize missing implicit end - // dimensions: +// Determine the common broadcast shape of all the index tensors. +std::pair, llvm::SmallVector> +getBroadcastShape(Location loc, llvm::ArrayRef indices, OpBuilder b) { int64_t indicesRank = 0; for (auto index : indices) { auto indexTy = cast(index.getType()); @@ -567,6 +557,8 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, return std::max(dim0, dim1); }; + Value torchCstOne = + b.create(loc, b.getI64IntegerAttr(1)); llvm::SmallVector broadcastSizes(indicesRank, torchCstOne); llvm::SmallVector broadcastShape(indicesRank, 0); for (auto index : indices) { @@ -585,6 +577,21 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, broadcastShape[idx] = maxDim(size, broadcastShape[idx]); } } + return std::make_pair(broadcastSizes, broadcastShape); +} + +Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, + OpBuilder b) { + llvm::SmallVector indices(indicesRef); + // Declare commonly used constants up front: + Value torchCstZero = + b.create(loc, b.getI64IntegerAttr(0)); + Value torchCstOne = + b.create(loc, b.getI64IntegerAttr(1)); + Value torchCstNegOne = + b.create(loc, b.getI64IntegerAttr(-1)); + + auto [broadcastSizes, broadcastShape] = getBroadcastShape(loc, indicesRef, b); auto mulDim = [](int64_t dim0, int64_t dim1) { if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) @@ -733,6 +740,34 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, return b.create(loc, valuesTy, values, outDimsList); } +// Broadcast the `values` tensor to the slice size created by the list of index +// tensors. +static Value broadcastValuesToSliceSize(Location loc, Value input, Value values, + llvm::ArrayRef indices, + OpBuilder b) { + auto inputType = cast(input.getType()); + ArrayRef inputStaticShape = inputType.getSizes(); + auto valuesType = cast(values.getType()); + + // In the case where the input rank is greater than the number of index + // tensors, the remaining dimensions of the input are indexed in their + // entirety. Thus, we need to append the remaining dimensions to get the shape + // of the indexed slice. + auto [resultShape, resultStaticShape] = getBroadcastShape(loc, indices, b); + for (size_t i = indices.size(); i < inputStaticShape.size(); i++) { + Value dim = b.create(loc, b.getI64IntegerAttr(i)); + resultShape.push_back(b.create(loc, input, dim)); + resultStaticShape.push_back(inputStaticShape[i]); + } + + auto resultType = b.getType( + resultStaticShape, valuesType.getOptionalDtype()); + Value broadcastShapeList = b.create( + loc, Torch::ListType::get(b.getType()), resultShape); + return b.create(loc, resultType, values, + broadcastShapeList); +} + class ConvertAtenIndexPutHackedTwinOp : public OpConversionPattern { public: @@ -780,6 +815,8 @@ class ConvertAtenIndexPutHackedTwinOp if (optionalIndicesCount == 0) return rewriter.notifyMatchFailure(op, "Indices list must not be empty."); + values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList, + rewriter); // Filter to available indices and get the indicesMap: SmallVector indicesList; SmallVector indicesMap; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a0d7616a6a957..8db4414bbb204 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1494,7 +1494,7 @@ "RenormModuleFloat32_basic", } -STABLEHLO_CRASHING_SET = set() +STABLEHLO_CRASHING_SET = {"IndexPutWithNoneAndBroadcastModule_basic"} # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. @@ -2427,6 +2427,7 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", "IntFloatModule_basic", "IntImplicitModule_basic", "IouOfModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index 8f7ea32910d67..ba44dc076904c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -1269,3 +1269,36 @@ def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): tu.randint(7, high=5), tu.rand(2, 3, 6, 7), ) + + +# ============================================================================== + + +class IndexPutWithNoneAndBroadcastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 4, 5], torch.float32, True), + ([6, 1], torch.int64, True), + ([7], torch.int64, True), + ([1, 6, 7], torch.float32, True), + ] + ) + def forward(self, input, index1, index2, value): + return torch.ops.aten.index_put( + input, (None, None, index1, index2), value, accumulate=True + ) + + +@register_test_case(module_factory=lambda: IndexPutWithNoneAndBroadcastModule()) +def IndexPutWithNoneAndBroadcastModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 3, 4, 5), + tu.randint(6, 1, high=4), + tu.randint(7, high=5), + tu.rand(1, 6, 7), # broadcasted to (2, 3, 6, 7) + ) From 6eebe61bfe8b0b774d178b654bd022bf561ed865 Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Wed, 26 Jun 2024 09:10:14 -0700 Subject: [PATCH 11/30] [Tosa] Conversion from torch.__interpolate to tosa.resize() (#3488) Signed-off-by: Suraj Sudhir --- .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 32 ++- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 189 ++++++++++++++++++ test/Conversion/TorchToTosa/basic.mlir | 56 ++++++ 3 files changed, 276 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index d5db519bef17e..a5a58064489a9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -208,6 +208,37 @@ m_TorchListOfOptionalConstantInts( return detail::torch_list_of_optional_constant_ints_op_binder(bind_values); } +namespace detail { +/// Matches the constant floats stored in a `torch.prim.ListConstruct`. +struct torch_list_of_constant_floats_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. + torch_list_of_constant_floats_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto listConstruct = dyn_cast(op); + if (!listConstruct) + return false; + for (Value value : listConstruct.getElements()) { + double num; + if (matchPattern(value, m_TorchConstantFloat(&num))) + bind_values.push_back(num); + else + return false; + } + return true; + } +}; +} // namespace detail + +/// Matches the constant integers stored in a `torch.prim.ListConstruct`. +inline detail::torch_list_of_constant_floats_op_binder +m_TorchListOfConstantFloats(SmallVectorImpl &bind_values) { + return detail::torch_list_of_constant_floats_op_binder(bind_values); +} + namespace detail { /// Matches the constant bools stored in a `torch.ListConstruct`. struct torch_list_of_constant_bools_op_binder { @@ -238,7 +269,6 @@ inline detail::torch_list_of_constant_bools_op_binder m_TorchListOfConstantBools(SmallVectorImpl &bind_values) { return detail::torch_list_of_constant_bools_op_binder(bind_values); } - namespace detail { /// Matches the constant strs stored in a `torch.ListConstruct`. struct torch_list_of_constant_strs_op_binder { diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 524dc953e866b..385c5e6ec35fb 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include #include using namespace mlir; @@ -5088,6 +5089,193 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult +ConvertAtenOp::matchAndRewrite( + Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Converts torch.aten.__interpolate.size_list_scale_list to tosa.resize + auto input = adaptor.getInput(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy) + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); + auto inputRank = inputTy.getRank(); + if (inputRank != 4) + return rewriter.notifyMatchFailure(op, + "TOSA resize() takes rank==4 tensors."); + + auto inputShape = inputTy.getShape(); + auto inputElemTy = inputTy.getElementType(); + // TOSA works in NHWC. Perform the necessary transformations. + std::optional nchwToNhwcTransposeConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/{0, 2, 3, 1}, + /*shape=*/{static_cast(4)}); + SmallVector transposedInputShape( + {inputShape[0], inputShape[2], inputShape[3], inputShape[1]}); + auto transposedInputTy = RankedTensorType::get( + makeShapeLLVMCompatible(transposedInputShape), inputElemTy); + auto transposedInput = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(transposedInputTy), + input, nchwToNhwcTransposeConst.value()) + .getResult(); + + auto inputHeight = transposedInputShape[1]; + auto inputWidth = transposedInputShape[2]; + + int outputHeight, outputWidth; + if (!isa(op.getScaleFactor().getType())) { + SmallVector scaleFactor; + if (!matchPattern(op.getScaleFactor(), + m_TorchListOfConstantFloats(scaleFactor))) + return rewriter.notifyMatchFailure( + op, "non-const scale_factor parameter unsupported"); + + outputHeight = inputHeight * scaleFactor[0]; + outputWidth = inputWidth * scaleFactor[1]; + + } else { + if (!isa(op.getSize().getType())) + return rewriter.notifyMatchFailure( + op, "Scale factor and size are both absent!"); + + SmallVector size; + if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) + return rewriter.notifyMatchFailure( + op, "non-const size parameter unsupported"); + outputHeight = size[0]; + outputWidth = size[1]; + } + + std::string pyMode; + if (!matchPattern(op.getMode(), m_TorchConstantStr(pyMode))) + return rewriter.notifyMatchFailure(op, + "non-const mode parameter unsupported"); + + // All torch modes listed in + // https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + if (pyMode != "bilinear" && pyMode != "nearest") + return rewriter.notifyMatchFailure( + op, "Only nearest and bilinear interpolation modes supported"); + + std::string mode; + if (pyMode == "bilinear") { + mode = "BILINEAR"; + } else { + mode = "NEAREST_NEIGHBOR"; + } + + bool alignCorners; + if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners))) + return rewriter.notifyMatchFailure( + op, "non-const align_corners parameter unsupported"); + + bool recomputeScaleFactor; + if (isa(op.getRecomputeScaleFactor().getType())) + recomputeScaleFactor = false; + else if (!matchPattern(op.getRecomputeScaleFactor(), + m_TorchConstantBool(&recomputeScaleFactor))) + return rewriter.notifyMatchFailure( + op, "non-const recompute_scale_factor parameter unsupported"); + if (recomputeScaleFactor) + return rewriter.notifyMatchFailure( + op, "Application of recompute_scale_factor not yet supported"); + + bool antialias; + if (!matchPattern(op.getAntialias(), m_TorchConstantBool(&antialias))) + return rewriter.notifyMatchFailure( + op, "non-const antialias parameter unsupported"); + if (antialias) + return rewriter.notifyMatchFailure( + op, "Application of antialias not yet supported"); + + SmallVector transposedResizedOpShape( + {inputShape[0], outputHeight, outputWidth, inputShape[1]}); + auto transposedResizedOpTy = RankedTensorType::get( + makeShapeLLVMCompatible(transposedResizedOpShape), inputElemTy); + + // Formatting snake_case to match TOSA spec names for readability + int scale_y_n, scale_y_d, offset_y, border_y; + int scale_x_n, scale_x_d, offset_x, border_x; + + // Align corners sets the scaling ratio to (OH - 1)/(IH - 1) + // rather than OH / IH. Similarly for width. + auto normalize = [&](int input, int output, int &n, int &d, int &offset, + int &border) { + // Dimension is length 1, we are just sampling from one value. + if (input == 1) { + n = output; + d = 1; + offset = 0; + border = output - 1; + return; + } + + // Apply if aligned and capable to be aligned. + bool apply_aligned = alignCorners && (output > 1); + n = apply_aligned ? (output - 1) : output; + d = apply_aligned ? (input - 1) : input; + + // Simplify the scalers, make sure they are even values. + int gcd = std::gcd(n, d); + n = 2 * n / gcd; + d = 2 * d / gcd; + + offset = 0; + + // If nearest neighbours we need to guarantee we round up. + if (mode == "NEAREST_NEIGHBOR" && alignCorners) { + offset += n / 2; + } + + // TBD: impact of antialias parameter here ? + + // We can compute this directly based on previous values. + border = d * (output - 1) - n * (input - 1) + offset; + }; + + normalize(inputHeight, outputHeight, scale_y_n, scale_y_d, offset_y, + border_y); + normalize(inputWidth, outputWidth, scale_x_n, scale_x_d, offset_x, border_x); + + DenseI64ArrayAttr scale = rewriter.getDenseI64ArrayAttr( + {scale_y_n, scale_y_d, scale_x_n, scale_x_d}); + DenseI64ArrayAttr offset = + rewriter.getDenseI64ArrayAttr({offset_y, offset_x}); + DenseI64ArrayAttr border = + rewriter.getDenseI64ArrayAttr({border_y, border_x}); + StringAttr modeAttr = rewriter.getStringAttr(mode); + + auto resizeOpResult = + rewriter + .create(op->getLoc(), transposedResizedOpTy, + transposedInput, scale, offset, border, + modeAttr) + .getResult(); + + auto resultType = + cast(typeConverter->convertType(op.getType())); + std::optional nhwcToNchwTransposeConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/{0, 3, 1, 2}, + /*shape=*/{static_cast(4)}); + // SmallVector transposedOutputShape( + // {transposedResizedOpShape[0], transposedResizedOpShape[3], + // transposedResizedOpShape[1], transposedResizedOpShape[2]}); + // auto transposedOutputType = RankedTensorType::get( + // makeShapeLLVMCompatible(transposedOutputShape), inputElemTy); + rewriter + .replaceOpWithNewOp( + op, getTypeConverter()->convertType(resultType), resizeOpResult, + nhwcToNchwTransposeConst.value()) + .getResult(); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -5340,6 +5528,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenSqrtOp); INSERT_ATENOP_PATTERN(AtenIscloseOp); + INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 4c0dc01938767..35007f2a2a388 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1302,3 +1302,59 @@ func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5] %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1> return %0 : !torch.vtensor<[5,5],i1> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.__interpolate.size_list_scale_list.bilinear( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,16,135,240],f32> -> tensor<1x16x135x240xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.str "bilinear" +// CHECK: %[[VAL_5:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.float, !torch.float) -> !torch.list +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] : (tensor<1x16x135x240xf32>, tensor<4xi32>) -> tensor<1x135x240x16xf32> +// CHECK: %[[VAL_9:.*]] = tosa.resize %[[VAL_8]] {border = array, mode = "BILINEAR", offset = array, scale = array} : (tensor<1x135x240x16xf32>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_9]], %[[VAL_10]] : (tensor<1x270x480x16xf32>, tensor<4xi32>) -> tensor<1x16x270x480xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,16,270,480],f32> +// CHECK: } +func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { + %none = torch.constant.none + %false = torch.constant.bool false + %str = torch.constant.str "bilinear" + %float2.000000e00 = torch.constant.float 2.000000e+00 + %0 = torch.prim.ListConstruct %float2.000000e00, %float2.000000e00 : (!torch.float, !torch.float) -> !torch.list + %1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32> + return %1 : !torch.vtensor<[1,16,270,480],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.__interpolate.size_list_scale_list.nearest( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,16,135,240],f32> -> tensor<1x16x135x240xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.str "nearest" +// CHECK: %[[VAL_5:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.float, !torch.float) -> !torch.list +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] : (tensor<1x16x135x240xf32>, tensor<4xi32>) -> tensor<1x135x240x16xf32> +// CHECK: %[[VAL_9:.*]] = tosa.resize %[[VAL_8]] {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} : (tensor<1x135x240x16xf32>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_9]], %[[VAL_10]] : (tensor<1x270x480x16xf32>, tensor<4xi32>) -> tensor<1x16x270x480xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,16,270,480],f32> +// CHECK: } +func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { + %none = torch.constant.none + %false = torch.constant.bool false + %str = torch.constant.str "nearest" + %float2.000000e00 = torch.constant.float 2.000000e+00 + %0 = torch.prim.ListConstruct %float2.000000e00, %float2.000000e00 : (!torch.float, !torch.float) -> !torch.list + %1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32> + return %1 : !torch.vtensor<[1,16,270,480],f32> +} From 6678e1a2560e2630b7d3839dd44ce3b0b5c81b55 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 27 Jun 2024 08:43:10 +0200 Subject: [PATCH 12/30] TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) Before this PR, a statically shaped aten.convolution would generate dynamically shaped linalg IR, and even `-canonicalize` would not be able to fold it back into static shapes. This PR ensure that shape calculations are folded on construction to directly generate statically shaped linalg IR. We achieve that by ensuring that `arith` ops involved in computing shapes are created via `createOrFold`, so that later uses of `getAsOpFoldResult` see constants instead of those ops. For example ``` module { func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>, %arg1: !torch.vtensor<[336,168,3,3],f32>, %arg2: !torch.vtensor<[336],f32>) -> !torch.vtensor<[32,336,56,56],f32> { %false = torch.constant.bool false %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list %2 = torch.prim.ListConstruct : () -> !torch.list %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2 : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[32,336,56,56],f32> return %3 : !torch.vtensor<[32,336,56,56],f32> } } ``` would result in ``` [...] %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] { ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): tensor.yield %cst : f32 } : tensor<32x336x112x112xf32> to tensor [...] %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%expanded, %expanded_37 : tensor, tensor<2x168x168x3x3xf32>) outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32> [...] ``` and with this PR all shapes are static. --- lib/Conversion/TorchToLinalg/Linear.cpp | 2 +- lib/Conversion/TorchToLinalg/Utils.cpp | 32 +++++------ lib/Conversion/Utils/Utils.cpp | 4 +- .../Transforms/BackendTypeConversion.cpp | 2 +- .../Conversion/TorchToLinalg/elementwise.mlir | 2 +- test/Conversion/TorchToLinalg/pooling.mlir | 22 ++++---- .../Conversion/TorchToLinalg/view_strict.mlir | 15 +++--- test/Conversion/TorchToSCF/basic.mlir | 10 ++-- .../TorchToStablehlo/elementwise.mlir | 36 +++++-------- test/Conversion/TorchToStablehlo/linear.mlir | 27 +++++----- .../TorchToStablehlo/view_like.mlir | 54 +++++++------------ 11 files changed, 85 insertions(+), 121 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index c72db61c42fcf..8e55707f299cf 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -860,7 +860,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type intType = IntegerType::get(context, 64); auto castIndexToInt = [&](Value v) { - return rewriter.create(loc, intType, v); + return rewriter.createOrFold(loc, intType, v); }; SmallVector paddingIntValues; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index c2658f35cce3e..46b51558f13d6 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -82,16 +82,10 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( pad < paddingIncludingUnchanged.end(); pad++) *pad = castIntToIndex(b, loc, *pad); - Type elementType = cast(input.getType()).getElementType(); - // TODO: audit possibility of sparsity on this tensor - Type inputType = - RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef( - SmallVector(inRank, kUnknownSize))), - elementType); - SmallVector paddingValues = getAsOpFoldResult(paddingIncludingUnchanged); - return b.create(loc, inputType, input, /*low=*/paddingValues, + + return b.create(loc, Type{}, input, /*low=*/paddingValues, /*high=*/paddingValues, pad); } @@ -103,25 +97,25 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, Value c1 = b.create(loc, b.getI64IntegerAttr(1)); Value c2 = b.create(loc, b.getI64IntegerAttr(2)); - Value doublePadding = b.create(loc, paddingInt, c2); + Value doublePadding = b.createOrFold(loc, paddingInt, c2); // in + 2 * padding - Value inAddDoublePadding = - b.create(loc, castIndexToInt64(b, loc, in), doublePadding); + Value inAddDoublePadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), doublePadding); // dilation * (kernelSize - 1) - Value kernelSizeSub1 = b.create(loc, kernelSizeInt, c1); + Value kernelSizeSub1 = b.createOrFold(loc, kernelSizeInt, c1); Value dilationTimesKernelSize = - b.create(loc, dilationInt, kernelSizeSub1); + b.createOrFold(loc, dilationInt, kernelSizeSub1); - Value temp = - b.create(loc, inAddDoublePadding, dilationTimesKernelSize); - Value dividend = b.create(loc, temp, c1); + Value temp = b.createOrFold(loc, inAddDoublePadding, + dilationTimesKernelSize); + Value dividend = b.createOrFold(loc, temp, c1); Value division; if (ceilMode) - division = b.create(loc, dividend, strideInt); + division = b.createOrFold(loc, dividend, strideInt); else - division = b.create(loc, dividend, strideInt); - Value out = b.create(loc, division, c1); + division = b.createOrFold(loc, dividend, strideInt); + Value out = b.createOrFold(loc, division, c1); return castIntToIndex(b, loc, out); } diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 703bd2049f695..4af9709fdfd79 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -140,12 +140,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value castIntToIndex(OpBuilder &b, Location loc, Value v) { assert(isa(v.getType()) && "must be called with integer type"); - return b.create(loc, b.getIndexType(), v); + return b.createOrFold(loc, b.getIndexType(), v); } Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) { assert(isa(idx.getType()) && "must be called with integer type"); - return b.create(loc, b.getI64Type(), idx); + return b.createOrFold(loc, b.getI64Type(), idx); } SmallVector diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index c4f22715ab341..0f2533e063f0d 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -94,7 +94,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, if (!isa(inputs[0].getType())) return std::nullopt; assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); + return builder.createOrFold(loc, inputs[0]); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index 85be9f754d338..aa2be74f5d7e1 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -67,7 +67,7 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[C1:.*]] = torch.constant.int 1 -// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]] +// CHECK: %[[BUILTIN_C1:.*]] = arith.constant 1 : i64 // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: %[[ALPHA:.*]] = arith.sitofp %[[BUILTIN_C1]] : i64 to f32 diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 494f603c296eb..558c50c4f08f1 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -7,13 +7,13 @@ func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 %false = torch.constant.bool false - // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1 + // CHECK: %[[C1:.*]] = arith.constant 1 : i64 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3] // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor) -> tensor - // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index - // CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor - // CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor) outs(%[[OUT]] : tensor) -> tensor + // CHECK: %[[T1:.*]] = arith.constant 1 : index + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32> + // CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor<1xf32>) outs(%[[OUT]] : tensor) -> tensor %kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list %stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list @@ -33,15 +33,15 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt %int7 = torch.constant.int 7 %int8 = torch.constant.int 8 %false = torch.constant.bool false - // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1 - // CHECK: %[[C2:.*]] = torch_c.to_i64 %int2 + // CHECK: %[[C1:.*]] = arith.constant 1 : i64 + // CHECK: %[[C2:.*]] = arith.constant 2 : i64 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6] // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor) -> tensor - // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index - // CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index - // CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]], %[[T2]]) : tensor - // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor) outs(%[[OUT]] : tensor) -> tensor + // CHECK: %[[T1:.*]] = arith.constant 1 : index + // CHECK: %[[T2:.*]] = arith.constant 2 : index + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32> + // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor<1x2xf32>) outs(%[[OUT]] : tensor) -> tensor %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list @@ -88,7 +88,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. // CHECK: } : tensor to tensor // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor) -> tensor - // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { + // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor<8x8x8xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32): // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 diff --git a/test/Conversion/TorchToLinalg/view_strict.mlir b/test/Conversion/TorchToLinalg/view_strict.mlir index 8be9a2f9fb5a3..a900fbb069276 100644 --- a/test/Conversion/TorchToLinalg/view_strict.mlir +++ b/test/Conversion/TorchToLinalg/view_strict.mlir @@ -7,10 +7,8 @@ // CHECK-LABEL: func.func @torch.aten.view$twotothree // CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32> -// CHECK: %[[T3:.*]] = torch.constant.int 3 -// CHECK: %[[T2:.*]] = torch.constant.int 2 -// CHECK: %[[N2:.*]] = torch_c.to_i64 %[[T2]] -// CHECK: %[[N3:.*]] = torch_c.to_i64 %[[T3]] +// CHECK: %[[N2:.*]] = arith.constant 2 : i64 +// CHECK: %[[N3:.*]] = arith.constant 3 : i64 // CHECK: %[[ELEMENTS:.*]] = tensor.from_elements %[[N2]], %[[N3]] : tensor<2xi64> // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[ARG0]](%[[ELEMENTS]]) : (tensor<3x2xf32>, tensor<2xi64>) -> tensor<2x3xf32> func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> @@ -112,13 +110,12 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // reshape. Someday, this should generate flatten/unflatten. // CHECK-LABEL: func.func @torch.aten$dynamicValOutput // CHECK: %[[SELF:.*]] = torch_c.to_builtin_tensor %arg0 -// CHECK: %[[CONSTANT1:.*]] = torch.constant.int 1 // CHECK-DAG: %[[PROD1:.*]] = arith.constant 1 // CHECK-DAG: %[[ARG1_CVT:.*]] = torch_c.to_i64 %arg1 // CHECK-DAG: %[[PROD2:.*]] = arith.muli %[[PROD1]], %[[ARG1_CVT]] -// CHECK-DAG: %[[ONEI64:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK-DAG: %[[ONEI64:.*]] = arith.constant 1 : i64 // CHECK-DAG: %[[PROD3:.*]] = arith.muli %[[PROD2]], %[[ONEI64]] -// CHECK-DAG: %[[ONEI64_0:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK-DAG: %[[ONEI64_0:.*]] = arith.constant 1 : i64 // CHECK-DAG: %[[PROD4:.*]] = arith.muli %[[PROD3]], %[[ONEI64_0]] // CHECK-DAG: %[[INDEX0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[DIM0_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX0]] : tensor @@ -134,8 +131,8 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // CHECK-DAG: %[[KNOWN2:.*]] = arith.muli %[[KNOWN1]], %[[DIM2]] : i64 // CHECK-DAG: %[[DIMINFER:.*]] = arith.divui %[[KNOWN2]], %[[PROD4]] : i64 // CHECK: %[[DIM0:.*]] = torch_c.to_i64 %arg1 -// CHECK: %[[DIM1:.*]] = torch_c.to_i64 %[[CONSTANT1]] -// CHECK: %[[DIM3:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK: %[[DIM1:.*]] = arith.constant 1 : i64 +// CHECK: %[[DIM3:.*]] = arith.constant 1 : i64 // CHECK: %[[OUTPUT_DIMS:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]], %[[DIMINFER]], %[[DIM3]] : tensor<4xi64> // CHECK: tensor.reshape %[[SELF]](%[[OUTPUT_DIMS]]) : (tensor, tensor<4xi64>) -> tensor // diff --git a/test/Conversion/TorchToSCF/basic.mlir b/test/Conversion/TorchToSCF/basic.mlir index aa04c6d72a40c..dd64e99b8c240 100644 --- a/test/Conversion/TorchToSCF/basic.mlir +++ b/test/Conversion/TorchToSCF/basic.mlir @@ -4,9 +4,9 @@ // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int { // CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %[[VAL_2]] +// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : i64 // CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_1]] -> (i64) { // CHECK: scf.yield %[[VAL_3]] : i64 // CHECK: } else { @@ -31,11 +31,11 @@ func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int { // CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_1]] // CHECK: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] +// CHECK: %[[VAL_5:.*]] = arith.constant 2 : i64 // CHECK: %[[VAL_6:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i64 // CHECK: %[[VAL_8:.*]] = torch.constant.int 4 -// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_8]] +// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64 // CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_2]] -> (i64) { // CHECK: %[[VAL_11:.*]] = scf.if %[[VAL_3]] -> (i64) { // CHECK: scf.yield %[[VAL_5]] : i64 diff --git a/test/Conversion/TorchToStablehlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir index 6403db6f2bcc7..104f6e0d87616 100644 --- a/test/Conversion/TorchToStablehlo/elementwise.mlir +++ b/test/Conversion/TorchToStablehlo/elementwise.mlir @@ -103,8 +103,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. // CHECK-LABEL: func.func @torch.aten.addscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> @@ -124,10 +123,8 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.addscalar$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -167,8 +164,7 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -204,8 +200,7 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK-LABEL: func.func @torch.aten.subscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> @@ -225,8 +220,7 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.rsubscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> @@ -246,10 +240,8 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.subscalar$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -289,8 +281,7 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -326,8 +317,7 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK-LABEL: func.func @torch.aten.mulscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -359,8 +349,7 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.divscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -392,8 +381,7 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.gt.scalar( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]] +// CHECK: %[[T1:.*]] = arith.constant 3 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index a333c93e9dfdb..db61dc262d026 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -278,7 +278,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK: %[[T_5:.*]] = torch.constant.int 1 // CHECK: %[[T_6:.*]] = torch.constant.int 4 // CHECK: %[[T_7:.*]] = torch.constant.int 3 -// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]] +// CHECK: %[[T_8:.*]] = arith.constant 3 : i64 // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list @@ -314,8 +314,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // CHECK: %int2 = torch.constant.int 2 // CHECK: %int1 = torch.constant.int 1 // CHECK: %int4 = torch.constant.int 4 -// CHECK: %int3 = torch.constant.int 3 -// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3 +// CHECK: %[[T_3:.*]] = arith.constant 3 : i64 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -357,7 +356,7 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %[[T_2:.*]] = arith.constant 1 : i64 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> @@ -388,7 +387,7 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %[[T_2:.*]] = arith.constant 1 : i64 // CHECK: %int2 = torch.constant.int 2 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -423,7 +422,7 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %[[T_2:.*]] = arith.constant 1 : i64 // CHECK: %int2 = torch.constant.int 2 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -459,12 +458,12 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 // CHECK: %int2 = torch.constant.int 2 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2 +// CHECK: %[[T_2:.*]] = arith.constant 2 : i64 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32> -// CHECK: %[[T_7:.*]] = stablehlo.reverse %6, dims = [0, 1] : tensor<3x3x2x2xf32> +// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32> // CHECK: %c0 = arith.constant 0 : index // CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32> // CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64 @@ -477,14 +476,14 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %c3 = arith.constant 3 : index // CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32> // CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64 -// CHECK: %c2_i64 = arith.constant 2 : i64 -// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %c2_i64 : i64 -// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %c2_i64 : i64 -// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %c2_i64, %[[T_12]] : tensor<5xi64> +// CHECK: %[[C2:.*]] = arith.constant 2 : i64 +// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %[[C2]] : i64 +// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %[[C2]] : i64 +// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %[[C2]], %[[T_12]] : tensor<5xi64> // CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32> // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32> -// CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> -// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> +// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> +// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> // CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir index 3b01690364bdd..f956c13cff184 100644 --- a/test/Conversion/TorchToStablehlo/view_like.mlir +++ b/test/Conversion/TorchToStablehlo/view_like.mlir @@ -3,12 +3,9 @@ // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[INT10:.*]] = torch.constant.int 10 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT10]] +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 +// CHECK: %[[T3:.*]] = arith.constant 10 : i64 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -42,7 +39,7 @@ // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -58,12 +55,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 +// CHECK: %[[T3:.*]] = arith.constant 9223372036854775807 : i64 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -97,7 +91,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32> func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { @@ -113,12 +107,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK-LABEL: func.func @torch.aten.slice.last$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 1 : i64 +// CHECK: %[[T3:.*]] = arith.constant -1 : i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -152,7 +143,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { @@ -168,12 +159,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 1 : i64 +// CHECK: %[[T3:.*]] = arith.constant -1 : i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -207,7 +195,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32> func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { @@ -224,8 +212,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 2 : i64 // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor @@ -247,7 +234,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -264,8 +251,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 2 : i64 // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> @@ -287,7 +273,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> +// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32> func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { From 39d133200862a2b57fcb0c5c1b017b3239cf130c Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Thu, 27 Jun 2024 17:08:44 +0530 Subject: [PATCH 13/30] add onnx loop support (#3408) - Adds limited support for lowering onnx.Loop to primLoopOp - lower in the pipeline`torch-to-scf` there is a check to see if loop is for like. A primLoopOp is for like when the input condition is a `trueBoolConstant`. To adapt the onnx to torch lowering to take advantage of it, the implementation checks for specific op patterns in the loodBody region and decides if loop is for like and uses the right input condition op. - to adapt the onnxLoopBody to torchLoopBody, we need to adapt the input block arguments and set the correct output condition variable in the loop body. - scanOutput variables are currently not supported. --- .../Conversion/TorchOnnxToTorch/Patterns.h | 10 ++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 155 +++++++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 62 +++++++ 3 files changed, 226 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 90871110d20c1..90d05e8c8bb0a 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -209,6 +209,16 @@ struct OpBinder { return success(); } + ParseResult tensorOperandTypes(llvm::SmallVector &typeList) { + for (auto operand : op->getOperands()) { + auto t = toValidTensorType(operand.getType()); + if (!t) + return failure(); + typeList.push_back(t); + } + return success(); + } + // The importer imports Onnx.GraphProto attributes as regions attached to the // op. ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 0c7955b1e4938..40aaa6ac47e2c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -259,6 +259,159 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Loop", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Get all operands (maxTripCount, cond, ....inits....) + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || operands.size() == 0 || + binder.getNumOperands() < 2) { + return rewriter.notifyMatchFailure(binder.op, + "Failed to get required operands"); + } + + llvm::SmallVector operandTypeVec; + if (binder.tensorOperandTypes(operandTypeVec) || + operandTypeVec.size() == 0) { + return rewriter.notifyMatchFailure(binder.op, + "Failed to get operandTypes"); + } + + Region *loopBodyIn; + if (binder.getRegionAtIndex(loopBodyIn, 0)) { + return rewriter.notifyMatchFailure(binder.op, + "Failed getting LoopBody Region"); + } + + // MaxTripCount - tensor int64 scalar (or empty) + Value maxTripCountTensor = operands[0]; + auto maxTripCountInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + maxTripCountTensor); + + // Condition - tensor bool scalar (or empty) + Value conditionTensor = operands[1]; + auto conditionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + conditionTensor); + auto conditionBool = rewriter.create( + binder.getLoc(), rewriter.getType(), conditionInt); + // To be used for "for like" loop case + auto constBoolTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + + // Others (if present) - variadic (can be tensors and scalar values) + if (binder.getNumOperands() > 2) { + operandTypeVec.erase(operandTypeVec.begin(), + operandTypeVec.begin() + 2); + operands.erase(operands.begin(), operands.begin() + 2); + } + + auto getOpName = [](Operation *op) -> std::string { + std::string name = op->getName().getStringRef().str(); + if (name != "torch.operator") + return name; + // for unconverted onnx ops + return mlir::dyn_cast(op->getAttr("name")) + .getValue() + .str(); + }; + + // PrimLoop Op expectes inputCondition to be boolConstantTrue + // to decide if the loopOp is `forlike`. Use loopIsForLike to + // ensure appropriate inputCondition is set + // Case 1 : loopCondInp -> identity -> terminator(loopCondOut) + bool loopIsForLike = false; + auto case1ForLike = [&getOpName](Region *loopBody) -> bool { + Value onnxLoopBodyCondIn = loopBody->front().getArgument(1); + if (!onnxLoopBodyCondIn.hasOneUse()) + return false; + Operation *inpCondUser = *onnxLoopBodyCondIn.getUsers().begin(); + if (getOpName(inpCondUser) != "onnx.Identity") { + return false; + } + if (!inpCondUser->hasOneUse() || + getOpName(*(inpCondUser->getUsers().begin())) != + "torch.operator_terminator") + return false; + return true; + }; + loopIsForLike = case1ForLike(loopBodyIn); + + Value loopInitCondition = + loopIsForLike ? constBoolTrue : conditionBool.getResult(); + auto loc = binder.getLoc(); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + auto loop = b.create( + TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition, + ValueRange(operands)); + + rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(), + loop.getRegion().begin()); + + // primLoopOp loopBody expects torch.int as first arg + // insert torch.int arg in loop body, convert to tensor, + // replace all uses of old arg, delete old arg. + auto loopVarArg = loop.getRegion().front().getArgument(0); + // insert new Arg + loop.getRegion().front().insertArgument( + 0U, rewriter.getType(), binder.getLoc()); + auto newLoopVarArg = loop.getRegion().front().getArgument(0); + + // convert int arg to tensor of original Type + rewriter.setInsertionPointToStart(&loop.getRegion().front()); + Value loopVarVal = BlockArgument::Value(loopVarArg); + auto newTensor = rewriter.create( + loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(), + newLoopVarArg); + + loopVarArg.replaceAllUsesWith(newTensor); + loop.getRegion().eraseArgument(1); + + // primLoopOp loopBody has no condition arg + auto condArg = loop.getRegion().front().getArgument(1); + if (!condArg.use_empty()) + condArg.replaceAllUsesWith(conditionTensor); + + // replace terminator + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = loop.getRegion().front().getTerminator(); + rewriter.setInsertionPoint(terminator); + + // results - n loop carried dependencies and k scan outputs + // Fail when there are scanOutputs in onnxLoop (K>0); + // unsupported for now + if (terminator->getNumOperands() != + loop.getRegion().getNumArguments() - 1) { + return rewriter.notifyMatchFailure( + binder.op, "scanOutputs in loop body unsupported"); + } + + // Get remaining operands from onnxLoopBody's terminator Op + // these are all the loop carried dependencies in the loop body + auto terminatorOperands = terminator->getOperands(); + llvm::SmallVector remTerminatorOperands( + terminatorOperands.begin() + 1, terminatorOperands.end()); + Value terminatorCond; + if (loopIsForLike) { + terminatorCond = constBoolTrue; + } else { + // Only use when loop is not forlike + Value terminatorCondTensor = terminatorOperands[0]; + auto terminatorCondInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + terminatorCondTensor); + auto terminatorCondBool = rewriter.create( + binder.getLoc(), rewriter.getType(), + terminatorCondInt); + terminatorCond = terminatorCondBool.getResult(); + } + rewriter.replaceOpWithNewOp( + terminator, terminatorCond, remTerminatorOperands); + + loop.getRegion().eraseArgument(1); + rewriter.replaceOp(binder.op, loop); + return success(); + }); patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander); patterns.onOp( "LogSoftmax", 13, @@ -2197,7 +2350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp( - "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Identity", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value tensor; if (binder.tensorOperand(tensor) || diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c60ac654fb6b1..77991912c5e8e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1652,3 +1652,65 @@ func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list>) -> !torch.vtensor<[],i1> return %0 : !torch.vtensor<[],i1> } + +// ----- + +// CHECK-LABEL: func.func @test_loop_forlike +func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],i1>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "loop_example", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[MAX_TRIP_COUNT_INP:.*]]: !torch.vtensor<[],si64>, + // CHECK-SAME: %[[CONDITION_INP:.*]]: !torch.vtensor<[],i1>, + // CHECK-SAME: %[[LCD_1:.*]]: !torch.vtensor<[1],f32> + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[MAX_TRIP_COUNT_INT:.*]] = torch.aten.item %[[MAX_TRIP_COUNT_INP]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[CONDITION_INT:.*]] = torch.aten.item %[[CONDITION_INP]] : !torch.vtensor<[],i1> -> !torch.int + // CHECK: %[[CONDITION_BOOL:.*]] = torch.aten.Bool.int %[[CONDITION_INT]] : !torch.int -> !torch.bool + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[MAX_TRIP_COUNT_INT]], %[[TRUE]], init(%[[LCD_1]]) { + // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[LCD_1_BODY:.*]]: !torch.vtensor<[1],f32>): + // CHECK: %[[ITER_NUM_T:.*]] = torch.prim.NumToTensor.Scalar %[[ITER_NUM]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[NONE_1:.*]] = torch.constant.none + // CHECK: %[[CLONE_INP_COND:.*]] = torch.aten.clone %[[CONDITION_INP]], %[[NONE_1]] : !torch.vtensor<[],i1>, !torch.none -> !torch.vtensor<[],i1> + // CHECK: %[[CONST_ARR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>) : !torch.vtensor<[5],f32> + // CHECK: %[[ONE_T:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[ONE_0:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_ONE_T:.*]] = torch.aten.add.Tensor %[[ITER_NUM_T]], %[[ONE_T]], %[[ONE_0]] : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ZERO_T:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[ZERO_0:.*]] = torch.constant.int 0 + // CHECK: %[[ITER_NUM_RT:.*]] = torch.aten.unsqueeze %[[ITER_NUM_T]], %[[ZERO_0]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ZERO_1:.*]] = torch.constant.int 0 + // CHECK: %[[ADD_ONE_RT:.*]] = torch.aten.unsqueeze %[[ADD_ONE_T]], %[[ZERO_1]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[NONE_2:.*]] = torch.constant.none + // CHECK: %[[ONE_1:.*]] = torch.constant.int 1 + // CHECK: %[[ONE_SIZE_LIST:.*]] = torch.prim.ListConstruct %[[ONE_1]] : (!torch.int) -> !torch.list + // CHECK: %[[ONES_T:.*]] = torch.aten.ones %[[ONE_SIZE_LIST]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> + // CHECK: %[[ZERO_2:.*]] = torch.constant.int 0 + // CHECK: %[[ZERO_3:.*]] = torch.constant.int 0 + // CHECK: %[[ZERO_T_1:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITER_NUM_INDEXED:.*]] = torch.aten.index_select %[[ITER_NUM_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITER_NUM_INT:.*]] = torch.aten.item %[[ITER_NUM_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INC_INDEXED:.*]] = torch.aten.index_select %[[ADD_ONE_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[INC_INT:.*]] = torch.aten.item %[[INC_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_INDEX_T:.*]] = torch.aten.index_select %[[ONES_T]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[INDEX_INT:.*]] = torch.aten.item %[[SLICE_INDEX_T]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INPUT_SLICE:.*]] = torch.aten.slice.Tensor %[[CONST_ARR]], %[[ZERO_3]], %[[ITER_NUM_INT]], %[[INC_INT]], %[[INDEX_INT]] : !torch.vtensor<[5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[ONE_2:.*]] = torch.constant.int 1 + // CHECK: %[[INTERM_RES:.*]] = torch.aten.add.Tensor %[[LCD_1_BODY]], %[[INPUT_SLICE]], %[[ONE_2]] : !torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[INTERM_RES]] : !torch.vtensor<[1],f32>) + // CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> + // CHECK: return %[[LOOP]] : !torch.vtensor<[1],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.Loop"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> { + ^bb0(%arg3: !torch.vtensor<[],si64>, %arg4: !torch.vtensor<[],i1>, %arg5: !torch.vtensor<[1],f32>): + %1 = torch.operator "onnx.Identity"(%arg4) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[],i1> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>} : () -> !torch.vtensor<[5],f32> + %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %4 = torch.operator "onnx.Add"(%arg3, %3) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> + %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %6 = torch.operator "onnx.Unsqueeze"(%arg3, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64> + %7 = torch.operator "onnx.Unsqueeze"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64> + %8 = torch.operator "onnx.Slice"(%2, %6, %7) : (!torch.vtensor<[5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32> + %9 = torch.operator "onnx.Add"(%arg5, %8) : (!torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> + torch.operator_terminator %1, %9 : !torch.vtensor<[],i1>, !torch.vtensor<[1],f32> + } + return %0 : !torch.vtensor<[1],f32> +} From 6d0ca499e678f5913914d5cc3cabd460e483ab85 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:33:41 -0700 Subject: [PATCH 14/30] [ONNX] Add OnnxToTorch support for ReverseSequence (#3495) --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 78 ++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 112 ++++++++++++++++++ 2 files changed, 190 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 58d8397ee67c0..ec4a71294b0ee 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3564,4 +3564,82 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, permutedStft); return success(); }); + patterns.onOp( + "ReverseSequence", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, sequenceLens; + int64_t batchAxis, timeAxis; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(sequenceLens, 1) || + binder.s64IntegerAttr(batchAxis, "batch_axis", 1) || + binder.s64IntegerAttr(timeAxis, "time_axis", 0) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTy = cast(input.getType()); + SmallVector inputShape(inputTy.getSizes()); + auto dtype = resultType.getDtype(); + + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value batchAxisVal = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(batchAxis)); + Value timeAxisVal = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(timeAxis)); + + SmallVector sliceShape(inputShape); + sliceShape[batchAxis] = 1; + auto sliceType = + rewriter.getType(sliceShape, dtype); + SmallVector flipShape(sliceShape); + flipShape[timeAxis] = Torch::kUnknownSize; + auto flipType = + rewriter.getType(flipShape, dtype); + auto scalarTensorType = rewriter.getType( + ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)); + + for (int i = 0; i < inputShape[batchAxis]; i++) { + // slice i iterating on batch axis + Value k = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value end = + rewriter.create(binder.getLoc(), k, cstOne); + Value sliceBatch = rewriter.create( + binder.getLoc(), sliceType, input, batchAxisVal, k, end, cstOne); + + // get sequence length and slice the reversing part + Value kTensor = rewriter.create( + binder.getLoc(), scalarTensorType, k); + Value sel = rewriter.create( + binder.getLoc(), scalarTensorType, sequenceLens, cstZero, + kTensor); + Value len = rewriter.create( + binder.getLoc(), rewriter.getType(), sel); + Value sliceTime = rewriter.create( + binder.getLoc(), flipType, sliceBatch, timeAxisVal, cstZero, len, + cstOne); + // flip the sliced reversing tensor + Value dims = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{timeAxisVal}); + Value flip = rewriter.create( + binder.getLoc(), flipType, sliceTime, dims); + + // embeds the reversed tensor to the input + Value embedTime = rewriter.create( + binder.getLoc(), sliceType, sliceBatch, flip, timeAxisVal, + /*start=*/cstZero, /*end=*/len, /*step=*/cstOne); + input = rewriter.create( + binder.getLoc(), resultType, input, embedTime, batchAxisVal, + /*start=*/k, /*end=*/end, /*step=*/cstOne); + } + + rewriter.replaceOp(binder.op, input); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index d611823f9052d..095ee8c77b921 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2663,3 +2663,115 @@ func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !t %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> return %0 : !torch.vtensor<[1,15,9,2],f32> } + +// ----- + +// CHECK-LABEL: @test_reversesequence_batch +func.func @test_reversesequence_batch(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[C0_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %[[SLICE]], %[[C1_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %[[SLICE_0]], %[[DIM]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED:.*]] = torch.aten.slice_scatter %[[SLICE]], %[[FLIP]], %[[C1_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[EMBED_0:.*]] = torch.aten.slice_scatter %arg0, %[[EMBED]], %[[C0_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[C1_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[EMBED_0]], %[[C0_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_2:.*]] = torch.aten.slice.Tensor %[[SLICE_1]], %[[C1_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM_0:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[SLICE_2]], %[[DIM_0]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED_1:.*]] = torch.aten.slice_scatter %[[SLICE_1]], %[[FLIP_0]], %[[C1_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[EMBED_2:.*]] = torch.aten.slice_scatter %[[EMBED_0]], %[[EMBED_1]], %[[C0_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[C2]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_3:.*]] = torch.aten.slice.Tensor %[[EMBED_2]], %[[C0_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX_1:.*]] = torch.prim.NumToTensor.Scalar %[[C2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_1:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_4:.*]] = torch.aten.slice.Tensor %[[SLICE_3]], %[[C1_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM_1:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_1:.*]] = torch.aten.flip %[[SLICE_4]], %[[DIM_1]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED_3:.*]] = torch.aten.slice_scatter %[[SLICE_3]], %[[FLIP_1]], %[[C1_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[EMBED_4:.*]] = torch.aten.slice_scatter %[[EMBED_2]], %[[EMBED_3]], %[[C0_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[ADD_2:.*]] = torch.aten.add.int %[[C3]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_5:.*]] = torch.aten.slice.Tensor %[[EMBED_4]], %[[C0_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX_2:.*]] = torch.prim.NumToTensor.Scalar %[[C3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_2:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_6:.*]] = torch.aten.slice.Tensor %[[SLICE_5]], %[[C1_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM_2:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_2:.*]] = torch.aten.flip %[[SLICE_6]], %[[DIM_2]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED_5:.*]] = torch.aten.slice_scatter %[[SLICE_5]], %[[FLIP_2]], %[[C1_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: torch.aten.slice_scatter %[[EMBED_4]], %[[EMBED_5]], %[[C0_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + %0 = torch.operator "onnx.ReverseSequence"(%arg0, %arg1) {torch.onnx.batch_axis = 0 : si64, torch.onnx.time_axis = 1 : si64} : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_reversesequence_time +func.func @test_reversesequence_time(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[C0_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %[[SLICE]], %[[C0_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %[[SLICE_0]], %[[DIM]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED:.*]] = torch.aten.slice_scatter %[[SLICE]], %[[FLIP]], %[[C0_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[EMBED_0:.*]] = torch.aten.slice_scatter %arg0, %[[EMBED]], %[[C1_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[C1_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[EMBED_0]], %[[C1_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_2:.*]] = torch.aten.slice.Tensor %[[SLICE_1]], %[[C0_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM_0:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[SLICE_2]], %[[DIM_0]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED_1:.*]] = torch.aten.slice_scatter %[[SLICE_1]], %[[FLIP_0]], %[[C0_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[EMBED_2:.*]] = torch.aten.slice_scatter %[[EMBED_0]], %[[EMBED_1]], %[[C1_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[C2]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_3:.*]] = torch.aten.slice.Tensor %[[EMBED_2]], %[[C1_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX_1:.*]] = torch.prim.NumToTensor.Scalar %[[C2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_1:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_4:.*]] = torch.aten.slice.Tensor %[[SLICE_3]], %[[C0_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM_1:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_1:.*]] = torch.aten.flip %[[SLICE_4]], %[[DIM_1]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED_3:.*]] = torch.aten.slice_scatter %[[SLICE_3]], %[[FLIP_1]], %[[C0_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[EMBED_4:.*]] = torch.aten.slice_scatter %[[EMBED_2]], %[[EMBED_3]], %[[C1_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[ADD_2:.*]] = torch.aten.add.int %[[C3]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_5:.*]] = torch.aten.slice.Tensor %[[EMBED_4]], %[[C1_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX_2:.*]] = torch.prim.NumToTensor.Scalar %[[C3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_2:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_6:.*]] = torch.aten.slice.Tensor %[[SLICE_5]], %[[C0_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM_2:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_2:.*]] = torch.aten.flip %[[SLICE_6]], %[[DIM_2]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED_5:.*]] = torch.aten.slice_scatter %[[SLICE_5]], %[[FLIP_2]], %[[C0_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: torch.aten.slice_scatter %[[EMBED_4]], %[[EMBED_5]], %[[C1_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + %0 = torch.operator "onnx.ReverseSequence"(%arg0, %arg1) {torch.onnx.batch_axis = 1 : si64, torch.onnx.time_axis = 0 : si64} : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} From 1f73895f93e03b1804a8a52e82a0c3395b2c1a49 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 27 Jun 2024 19:28:02 -0700 Subject: [PATCH 15/30] [torch-mlir] bump to llvm/llvm-project@9b78ddf3b2abfb3e (#3491) This bump triggered an upstream assert. Includes a WAR for #3506. Also includes several things I needed to do to repro: * When TORCH_MLIR_TEST_CONCURRENCY=1, test runs will be printed. * Added TORCH_MLIR_TEST_VERBOSE=1 handling to enable verbose mode (useful on CI). --------- Co-authored-by: Stella Laurenzo --- docs/development.md | 14 ++++++++++++++ externals/llvm-project | 2 +- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 13 +++++++------ lib/Dialect/Torch/IR/TorchOps.cpp | 6 ++++-- projects/pt1/e2e_testing/main.py | 4 ++++ .../pt1/python/torch_mlir_e2e_test/framework.py | 14 +++++++++++++- python/torch_mlir/compiler_utils.py | 3 +++ 7 files changed, 46 insertions(+), 10 deletions(-) diff --git a/docs/development.md b/docs/development.md index 154b398f1ca16..771c4fcbef0e6 100644 --- a/docs/development.md +++ b/docs/development.md @@ -429,6 +429,20 @@ cd projects/pt1 python -m e2e_testing.main -f 'AtenEmbeddingBag' ``` +The default mode of running tests uses the multi-processing framework and is +not tolerant of certain types of errors. If encountering native crashes/hangs, +enable debug variables to run sequentially/in-process with more verbosity: + +``` +export TORCH_MLIR_TEST_CONCURRENCY=1 +export TORCH_MLIR_TEST_VERBOSE=1 +``` + +In this way, you can run under `gdb`, etc and get useful results. Having env +vars like this makes it easy to set in GH action files, etc. Note that the +verbose flags are very verbose. Basic sequential progress reports will be +printed regardless when not running in parallel. + ## Running unit tests. To run all of the unit tests, run: diff --git a/externals/llvm-project b/externals/llvm-project index 5207632f8698a..9b78ddf3b2abf 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 5207632f8698a2fab0c4cdcdf2f7ad9aaf96e06f +Subproject commit 9b78ddf3b2abfb3e2063e3dad2a326f5eabc1618 diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 05258f50617f8..943eda4239458 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -46,16 +46,17 @@ using namespace mlir::torch::TMTensor; static void getEffectsImpl( SmallVectorImpl> &effects, - ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) { - for (Value value : results) { + ResultRange results, ArrayRef inputBuffers, + ArrayRef outputBuffers) { + for (OpResult value : results) { effects.emplace_back(MemoryEffects::Allocate::get(), value, SideEffects::DefaultResource::get()); } - for (Value value : inputBuffers) { + for (OpOperand *value : inputBuffers) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); } - for (Value value : outputBuffers) { + for (OpOperand *value : outputBuffers) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), value, @@ -1121,8 +1122,8 @@ bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) { void OP_NAME::getEffects( \ SmallVectorImpl> \ &effects) { \ - SmallVector inputBuffers = getInputBufferOperands(); \ - SmallVector outputBuffers = getOutputBufferOperands(); \ + OpOperandVector inputBuffers = getInputBufferOperands(); \ + OpOperandVector outputBuffers = getOutputBufferOperands(); \ getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \ outputBuffers); \ } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index c37b96c60f664..b10a0c61fb553 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2810,7 +2810,8 @@ LogicalResult CopyToNonValueTensorOp::inferReturnTypes( void CopyToNonValueTensorOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Allocate::get(), getResult()); + effects.emplace_back(MemoryEffects::Allocate::get(), + getOperation()->getOpResult(0)); } //===----------------------------------------------------------------------===// @@ -2837,7 +2838,8 @@ LogicalResult CopyToValueTensorOp::inferReturnTypes( void CopyToValueTensorOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getOperand()); + effects.emplace_back(MemoryEffects::Read::get(), + &getOperation()->getOpOperand(0)); } //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index e9468ee919dae..4d0eb48618c1b 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -7,6 +7,10 @@ import re import sys +import torch + +torch.device("cpu") + from torch_mlir_e2e_test.framework import run_tests from torch_mlir_e2e_test.reporting import report_results from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index 42f4b5415d371..56c2e91ae4c60 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -358,6 +358,15 @@ def run_tests( if env_concurrency > 0: num_processes = min(num_processes, env_concurrency) + try: + env_verbose = os.getenv("TORCH_MLIR_TEST_VERBOSE", "0") + if env_verbose is not None: + verbose = bool(int(env_verbose)) + except ValueError as e: + raise ValueError( + "Bad value for TORCH_MLIR_TEST_VERBOSE env var: " "Expected integer." + ) from e + # TODO: We've noticed that on certain 2 core machine parallelizing the tests # makes the llvm backend legacy pass manager 20x slower than using a # single process. Need to investigate the root cause eventually. This is a @@ -375,7 +384,10 @@ def run_tests( # seems to cause a cascade of failures resulting in undecipherable error # messages. if num_processes == 1 or sequential: - return [compile_and_run_test(test, config, verbose) for test in tests] + print("Running tests sequentially with progress status") + for test in tests: + print(f"*** RUNNING TEST: {test.unique_name} ***") + compile_and_run_test(test, config, verbose) # This is needed because autograd does not support crossing process # boundaries. diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index 4e5a2f8f8c07f..c1315abd47f96 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -40,6 +40,9 @@ def run_pipeline_with_repro_report( ) # Lower module in place to make it ready for compiler backends. with module.context as ctx: + # TODO(#3506): Passes can emit errors but not signal failure, + # which causes a native assert. + ctx.emit_error_diagnostics = True pm = PassManager.parse(pipeline) if enable_ir_printing: ctx.enable_multithreading(False) From 23e3c0b5d268b193e46e50df6db6f36ea42eaa0b Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 27 Jun 2024 20:27:11 -0700 Subject: [PATCH 16/30] Bump llvm to d16b21b17d13ecd88a068bb803df43e53d3b04ba. (#3508) --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index 9b78ddf3b2abf..d16b21b17d13e 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 9b78ddf3b2abfb3e2063e3dad2a326f5eabc1618 +Subproject commit d16b21b17d13ecd88a068bb803df43e53d3b04ba From 7e6d76e997f438fe5bf540eaba9e0ee069bd9f6e Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Fri, 28 Jun 2024 16:06:52 +0200 Subject: [PATCH 17/30] [Torch] Fix torch.constant.int operation parsing (#3476) Due to the custom operation parser, the print and parser were expecting two different forms. One having the dictionary before the value and the other after. Following the format of the other constants ops, the constant.int will follow the `value attr-dict` format. Updated the parser accordingly. --- lib/Dialect/Torch/IR/TorchOps.cpp | 4 ++-- test/Dialect/Torch/ops.mlir | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index b10a0c61fb553..b10111f787638 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2882,11 +2882,11 @@ void ConstantDeviceOp::getAsmResultNames( ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) { Builder builder(result.getContext()); result.addTypes(builder.getType()); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); int64_t value; if (parser.parseInteger(value)) return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); result.addAttribute("value", builder.getI64IntegerAttr(value)); return success(); } diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 1fdbf6e1d7d34..29ab52f9dab0d 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -93,6 +93,9 @@ func.func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int { // CHECK: %int-3 = torch.constant.int -3 %int-3 = torch.constant.int -3 +// CHECK: %int5 = torch.constant.int 5 {test = "value"} +%int5 = torch.constant.int 5 {test = "value"} + // CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 %float1.000000e00 = torch.constant.float 1.000000e+00 // CHECK: %float-1.000000e00 = torch.constant.float -1.000000e+00 From 5a627c46b76f8cdc737aef3bda1b910836e33d88 Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Fri, 28 Jun 2024 20:08:43 +0530 Subject: [PATCH 18/30] onnx.DFT basic support (#3463) - adds support for DFT v20 on the FFT and IFFT path - adds required skeleton code for IFFT ops to be recognised in TMlir --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 ++++++ .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 91 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 47 ++++++++++ .../build_tools/abstract_interp_lib_gen.py | 20 ++++ .../build_tools/torch_ods_gen.py | 1 + .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 48 ++++++++++ 6 files changed, 233 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index be5bc56d7fe78..ae5f56aead12e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12418,6 +12418,32 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ }]; } +def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$n, + Torch_IntType:$dim, + AnyTorchOptionalStringType:$norm + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFftIfftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenFftIfftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 446298e89b336..a5cdc10208880 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -2728,4 +2728,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); + + patterns.onOp( + "DFT", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value inTensor, dftLength, axis; + Torch::ValueTensorType resultType; + int64_t inverse, onesided; + if (binder.tensorOperandAtIndex(inTensor, 0) || + binder.s64IntegerAttr(inverse, "inverse", 0) || + binder.s64IntegerAttr(onesided, "onesided", 0) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "Input Tensor / attrs / resultType bind failed"); + if (!binder.tensorOperandAtIndex(dftLength, 1)) { + // Convert to int and pass as n + dftLength = rewriter.create( + binder.getLoc(), rewriter.getType(), dftLength); + } else { + // Default for torch is None + dftLength = rewriter.create(binder.getLoc()); + } + // Default is same for onnx and torch + if (!binder.tensorOperandAtIndex(axis, 2)) { + // convert to int and pass to dims + axis = rewriter.create( + binder.getLoc(), rewriter.getType(), axis); + } else { + // Default in torch is -1 and onnx is -2 (since -1 is for real / img) + axis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(-2)); + } + + if (onesided == 1) + return rewriter.notifyMatchFailure(binder.op, + "Unsupported option : onesided"); + // norm default string attr + Value norm = rewriter.create( + binder.getLoc(), rewriter.getStringAttr(Twine("backward"))); + // Convert from [....., 2] complex number repr for fft consumption. + Torch::ValueTensorType inType = + binder.toValidTensorType(inTensor.getType()); + int64_t lastIndex = inType.getSizes().back(); + if (lastIndex != 1 && lastIndex != 2) + return rewriter.notifyMatchFailure( + binder.op, + "Expected input tensor to have dims [..., 1] or [..., 2]"); + + // concat with zeros to make it [..., 2] + Value inForComplexVal = inTensor; + ArrayRef inForComplexSizes = inType.getSizes().drop_back(); + if (lastIndex == 1) { + Value constZeroVal = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0)); + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value constZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value padSizeList = + rewriter + .create( + binder.getLoc(), + Torch::ListType::get(rewriter.getType()), + SmallVector({constZero, constOne})) + .getResult(); + Value modeVal = rewriter.create( + binder.getLoc(), rewriter.getStringAttr("constant")); + SmallVector resSize(inForComplexSizes); + resSize.push_back(2); + inForComplexVal = rewriter.create( + binder.getLoc(), + inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()), + inTensor, padSizeList, modeVal, constZeroVal); + } + Type inComplexTensorType = Torch::ValueTensorType::get( + binder.op->getContext(), inForComplexSizes, + mlir::ComplexType::get(inType.getDtype())); + Value inComplexTensor = rewriter.create( + binder.getLoc(), inComplexTensorType, inForComplexVal); + Value ftOp; + if (inverse == 0) { + ftOp = rewriter.create( + binder.getLoc(), inComplexTensorType, inComplexTensor, + /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); + } else { + ftOp = rewriter.create( + binder.getLoc(), inComplexTensorType, inComplexTensor, + /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); + } + rewriter.replaceOpWithNewOp(binder.op, + resultType, ftOp); + return success(); + }); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 6974636c0e86c..b05e1051c36f2 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10369,6 +10369,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %14 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" " %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" @@ -11984,6 +11987,50 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" } else {\n" +" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %11 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 0b356cc3412cf..b3d7ec5a9dec8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2038,6 +2038,9 @@ def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] = return out +def aten〇fft_ifft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: + return self + class DummyClassType: def __init__(self): pass @@ -3406,6 +3409,23 @@ def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length assert False, "Unsupported dtype" +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16})) +def aten〇fft_ifft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype): + return self_dtype + elif self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 90d3e10546849..fe700d2923e36 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -910,6 +910,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)" ) emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") + emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit( "aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)" diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 4b03fcceeec18..cf92c04d836f0 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -2480,3 +2480,51 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch %0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> return %0 : !torch.vtensor<[1,1,5,5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_dft_fft +func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>, + // CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SIG_LEN:.*]] = torch.constant.none + // CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NORM:.*]] = torch.constant.str "backward" + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "constant" + // CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32> + // CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex> + // CHECK: %[[FFT_CMPLX:.*]] = torch.aten.fft_fft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex> + // CHECK: %[[FFT_RES_REAL:.*]] = torch.aten.view_as_real %[[FFT_CMPLX]] : !torch.vtensor<[10,10],complex> -> !torch.vtensor<[10,10,2],f32> + // CHECK: return %[[FFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> + return %0 : !torch.vtensor<[10,10,2],f32> +} + +// CHECK-LABEL: func.func @test_dft_inverse_real +func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>, + // CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SIG_LEN:.*]] = torch.constant.none + // CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NORM:.*]] = torch.constant.str "backward" + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "constant" + // CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32> + // CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex> + // CHECK: %[[IFFT_CMPLX:.*]] = torch.aten.fft_ifft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex> + // CHECK: %[[IFFT_RES_REAL:.*]] = torch.aten.view_as_real %[[IFFT_CMPLX]] : !torch.vtensor<[10,10],complex> -> !torch.vtensor<[10,10,2],f32> + // CHECK: return %[[IFFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) {torch.onnx.inverse = 1 : si64} : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> + return %0 : !torch.vtensor<[10,10,2],f32> +} From f75cbb4df9bbf390281946203b08eb7ceb80a778 Mon Sep 17 00:00:00 2001 From: Jiawei Wu Date: Sat, 29 Jun 2024 00:07:55 +0800 Subject: [PATCH 19/30] [torch dialect] emit aten.fmax/fmin and add decomposition patterns (#3510) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 48 +++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 24 ++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 40 +++++++++++++ .../Transforms/LowerToBackendContract.cpp | 3 + projects/pt1/e2e_testing/xfail_sets.py | 4 ++ .../build_tools/abstract_interp_lib_gen.py | 22 +++++++ .../build_tools/torch_ods_gen.py | 2 + .../test_suite/elementwise.py | 58 +++++++++++++++++++ 8 files changed, 201 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index ae5f56aead12e..f4223b1f4bf72 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4671,6 +4671,54 @@ def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [ }]; } +def Torch_AtenFmaxOp : Torch_Op<"aten.fmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fmax : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenFmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenFminOp : Torch_Op<"aten.fmin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fmin : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFminOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenFminOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenMishOp : Torch_Op<"aten.mish", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b05e1051c36f2..f8a5409b8a707 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8940,6 +8940,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fmin\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fmax\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12471,6 +12479,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmax\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmin\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int5 = torch.constant.int 5\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7c2c29a6d720b..2086fb68afa25 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8493,6 +8493,41 @@ class DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp }; } // namespace +namespace { +// Decompose aten.fmax/fmin to aten.maximum/minimum + aten.where(nanMask) +template +class DecomposeAtenFMaxMinOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFOpT op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + BaseTensorType outType = cast(op.getType()); + Type nanMaskType = outType.getWithSizesAndDtype( + !outType.hasSizes() ? std::optional>() + : llvm::ArrayRef(outType.getSizes()), + rewriter.getI1Type()); + + Value self = op.getSelf(); + Value other = op.getOther(); + + Value normalResult = + rewriter.create(loc, outType, self, other).getResult(); + Value selfIsNan = + rewriter.create(loc, nanMaskType, self).getResult(); + Value otherIsNan = + rewriter.create(loc, nanMaskType, other) + .getResult(); + normalResult = rewriter.create( + loc, outType, otherIsNan, self, normalResult); + rewriter.replaceOpWithNewOp(op, outType, selfIsNan, other, + normalResult); + + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -8732,6 +8767,11 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenFMaxMinOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenFMaxMinOp>(patterns); + GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 21e2abb2474e9..15bebfc64390e 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -544,6 +544,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8db4414bbb204..19acd4d862289 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1673,6 +1673,8 @@ "ElementwiseFlattenBroadcastModule_basic", "ElementwiseFloorIntModule_basic", "ElementwiseFloorModule_basic", + "ElementwiseFmaxModule_basic", + "ElementwiseFminModule_basic", "ElementwiseGeFloatIntScalarModule_basic", "ElementwiseGeFloatScalarModule_basic", "ElementwiseGeIntScalarModule_basic", @@ -2215,6 +2217,8 @@ "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", + "ElementwiseFminModule_basic", + "ElementwiseFmaxModule_basic", "FlipModuleStaticShape_basic", "FlipNegativeIndexModule_basic", "PixelShuffleModuleStaticRank4Float32_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index b3d7ec5a9dec8..9052c8cc2057a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1403,6 +1403,12 @@ def aten〇minimum〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇maximum〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇fmin〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + +def aten〇fmax〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -3655,6 +3661,22 @@ def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: T dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇fmax〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇fmin〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) + # Different width diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fe700d2923e36..c3cb95dd7fbeb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -463,6 +463,8 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") emit("aten::minimum : (Tensor, Tensor) -> (Tensor)") + emit("aten::fmax : (Tensor, Tensor) -> (Tensor)") + emit("aten::fmin : (Tensor, Tensor) -> (Tensor)") emit("aten::mish : (Tensor) -> (Tensor)") emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") emit( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index ce000264efec3..b448bbaa49f6a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1440,6 +1440,64 @@ def ElementwiseMaximumIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseFmaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.fmax(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmaxModule()) +def ElementwiseFmaxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4)) + module.forward(tu.rand(4), torch.tensor([1.0, torch.nan, -0.5, -0.3])) + module.forward( + torch.tensor([0.8, torch.nan, torch.nan, -0.3]), + torch.tensor([1.0, torch.nan, -0.4, torch.nan]), + ) + + +# ============================================================================== + + +class ElementwiseFminModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.fmin(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFminModule()) +def ElementwiseFminModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4)) + module.forward(tu.rand(4), torch.tensor([1.0, torch.nan, -0.5, -0.3])) + module.forward( + torch.tensor([0.8, torch.nan, torch.nan, -0.3]), + torch.tensor([1.0, torch.nan, -0.4, torch.nan]), + ) + + +# ============================================================================== + + class ElementwiseMaxOtherModule(torch.nn.Module): def __init__(self): super().__init__() From a1c4089e71c8be1577217930bd9dddf13a6c76f5 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:20:29 -0400 Subject: [PATCH 20/30] Fix unused variable warning from assertion variable (#3512) Inlines a variable into an assertion that is not used elsewhere to fix build warnings. --- lib/Conversion/TorchToLinalg/Utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 46b51558f13d6..6ef947d890cdd 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -69,14 +69,14 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( int unpaddedDims, Value pad) { assert(isa(input.getType()) && "input must be RankedTensorType"); - unsigned int inRank = cast(input.getType()).getRank(); Location loc = op->getLoc(); SmallVector inputDims = getTensorSizes(b, loc, input); Value c0 = b.create(loc, b.getI64IntegerAttr(0)); SmallVector paddingIncludingUnchanged(unpaddedDims, c0); paddingIncludingUnchanged.append(padding); - assert(unpaddedDims + padding.size() == inRank && + assert(static_cast(unpaddedDims + padding.size()) == + cast(input.getType()).getRank() && "sum of unpaddedDims and padding.size() must equal to inputRank"); for (auto pad = paddingIncludingUnchanged.begin(); pad < paddingIncludingUnchanged.end(); pad++) From af236dab66778ab722b7c105c11ea710599f100f Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 28 Jun 2024 11:59:51 -0500 Subject: [PATCH 21/30] Add support for multiple dynamic reassociation dims for unflatten.int (#3504) Addresses an issue with onnx.Gather lowering to linalg: The builder for tensor.expand_shape, without an explicitly provided output shape, fails to infer an output shape in the case of multiple dynamic reassociation dims. I tried adding the output shape explicitly for tensor.expand_shape, but ran into compilation issues later on (see ). This PR adds support by lowering this op to tensor.reshape when multiple dynamic reassociation dims are provided. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 72 +++++++++++++++---- projects/pt1/e2e_testing/xfail_sets.py | 11 --- test/Conversion/TorchToLinalg/view.mlir | 27 +++++++ 3 files changed, 84 insertions(+), 26 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index dc8b5d4310023..475e0ec407d49 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -661,7 +661,8 @@ class ConvertAtenUnflattenIntOp "Expected input type having sizes"); } int inputRank = inputTensorType.getSizes().size(); - int outputRank = outputTensorType.getSizes().size(); + auto outputSizes = outputTensorType.getSizes(); + int outputRank = outputSizes.size(); int64_t dimInt; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) @@ -675,23 +676,64 @@ class ConvertAtenUnflattenIntOp auto sizesOp = op.getSizes().getDefiningOp(); int numSizes = sizesOp.getNumOperands(); - SmallVector reassociations(inputRank); - if (inputRank > 0) { - for (int i = 0; i < dimInt; ++i) - reassociations[i].push_back(i); - - for (int i = 0; i < numSizes; ++i) - reassociations[dimInt].push_back(i + dimInt); - - for (int i = dimInt + numSizes; i < outputRank; ++i) - reassociations[i - numSizes + 1].push_back(i); + int64_t numDynamicReassocDims = 0; + for (int64_t i = dimInt; i < dimInt + numSizes; i++) { + if (outputSizes[i] == Torch::kUnknownSize) + numDynamicReassocDims++; } + SmallVector reassocSizes; + if (!getListConstructElements(op.getSizes(), reassocSizes) && + numDynamicReassocDims > 1) + return rewriter.notifyMatchFailure( + op, "Must be able to either infer expansion dims, or retrieve them " + "from list construct"); + auto expandTy = getTypeConverter()->convertType(outputTensorType); - auto expand = rewriter - .create( - loc, expandTy, adaptor.getSelf(), reassociations) - .getResult(); + Value expand; + // When there are less than two dynamic reassociation dims, this will lower + // to tensor.expand_shape. Otherwise, this lowers to tensor.reshape. + // TODO: in the numDynamicReassocDims >= 2 case, lower to expand_shape with + // explicitly provided outputShape once + // https://github.com/iree-org/iree/issues/17760 is resolved. + if (numDynamicReassocDims < 2) { + SmallVector reassociations(inputRank); + if (inputRank > 0) { + for (int i = 0; i < dimInt; ++i) + reassociations[i].push_back(i); + for (int i = 0; i < numSizes; ++i) + reassociations[dimInt].push_back(i + dimInt); + for (int i = dimInt + numSizes; i < outputRank; ++i) + reassociations[i - numSizes + 1].push_back(i); + } + expand = rewriter + .create( + loc, expandTy, adaptor.getSelf(), reassociations) + .getResult(); + } else { + reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), + reassocSizes); + SmallVector inputShape = + getTensorSizes(rewriter, loc, adaptor.getSelf()); + inputShape = castIndexVectorToInt64Vector(rewriter, loc, inputShape); + SmallVector outputShape(inputShape.begin(), + inputShape.begin() + dimInt); + if (inputRank > 0) { + for (int i = 0; i < numSizes; ++i) + outputShape.push_back(reassocSizes[i]); + for (int i = dimInt + numSizes; i < outputRank; ++i) + outputShape.push_back(inputShape[i - numSizes + 1]); + } + + RankedTensorType shapeType = RankedTensorType::get( + ArrayRef{outputRank}, rewriter.getIntegerType(64)); + Value shapeValue = + rewriter.create(loc, shapeType, outputShape); + expand = rewriter + .create(loc, expandTy, adaptor.getSelf(), + shapeValue) + .getResult(); + } rewriter.replaceOp(op, expand); return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 19acd4d862289..bc99fde51b785 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2197,17 +2197,6 @@ ONNX_XFAIL_SET = { # Failure - cast error "PermuteNegativeIndexModule_basic", - # Failure - expand multiple dynamic dims - "EmbeddingModuleF16_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "IndexTensorHackedTwinModule3dInput_basic", - "IndexTensorHackedTwinModule_basic", - "IndexTensorModule3dInput_basic", - "IndexTensorModule_basic", - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - "IndexTensorSelectDimModule_basic", # Failure - incorrect numerics "AvgPool2dDivisorOverrideModule_basic", "BroadcastDynamicDimModule_basic", diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 3d265a308a0d0..2da7c0b74fc27 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -281,3 +281,30 @@ func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3], %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,?,2,3],f32>, !torch.list -> !torch.vtensor<[2,5,?,6],f32> return %1 : !torch.vtensor<[2,5,?,6],f32> } + +// ----- + +// this is to check a path for unflatten.int with two dynamic reassociation dims +// the IR here is generated from the onnx.Gather conversion +// CHECK-LABEL: @gather_graph +// CHECK: %[[fromelt:.*]] = tensor.from_elements +// CHECK-SAME: tensor<3xi64> +// CHECK: %[[reshape:.*]] = tensor.reshape +// CHECK-SAME: (tensor, tensor<3xi64>) -> tensor +func.func @gather_graph(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?,3],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + %int-1 = torch.constant.int -1 + %int5 = torch.constant.int 5 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.lt.Scalar %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],i1> + %1 = torch.aten.add.Scalar %arg1, %int5, %int1 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?],si64> + %2 = torch.aten.where.self %0, %1, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64> + %3 = torch.aten.size.int %2, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %4 = torch.aten.size.int %2, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %5 = torch.prim.ListConstruct %3, %4 : (!torch.int, !torch.int) -> !torch.list + %6 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %7 = torch.aten.view %2, %6 : !torch.vtensor<[?,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + %8 = torch.aten.index_select %arg0, %int0, %7 : !torch.vtensor<[5,3],f32>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,3],f32> + %9 = torch.aten.unflatten.int %8, %int0, %5 : !torch.vtensor<[?,3],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,?,3],f32> + return %9 : !torch.vtensor<[?,?,3],f32> +} From 6fece25ff3203bbc538756beb83fd513c19bcd7d Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Fri, 28 Jun 2024 10:18:36 -0700 Subject: [PATCH 22/30] [torch-mlir][sparse] add decomposition features to sparse compiler (#3505) Fixes https://github.com/llvm/torch-mlir/issues/3499 --- python/torch_mlir/extras/fx_decomp_util.py | 1 + test/python/fx_importer/sparse_test.py | 25 ++++++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py index 868dc26c6cb95..8dddede2d9ccf 100644 --- a/python/torch_mlir/extras/fx_decomp_util.py +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -49,6 +49,7 @@ torch.ops.aten.nan_to_num.default, torch.ops.aten.unbind, torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten.diag, ] diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 7c7198ef6f61a..699d57cb2b0d2 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -12,6 +12,7 @@ import torch.nn as nn import numpy as np +from torch_mlir.extras.fx_decomp_util import get_decomposition_table from torch_mlir.extras.fx_importer import FxImporter from torch_mlir.extras.fx_importer import SparsityMeta from torch_mlir import ir @@ -106,6 +107,9 @@ def sparse_export( # Build the regular FX traced graph with only dense arguments # (the current version would crash otherwise, see issue above). prog = torch.export.export(f, dargs, kwargs) + decomposition_table = get_decomposition_table() + if decomposition_table: + prog = prog.run_decompositions(decomposition_table) # Annotate sparse arguments in the graph and apply some very # basic propagation rules for sparsity. specs = prog.graph_signature.input_specs @@ -120,7 +124,6 @@ def sparse_export( node.meta["sparsity"] = sparse_metadata(args[k]) k = k + 1 elif node.op == "call_function": - # TODO: use upstream _opname implementation when available opname = node.target._schema.name.split("::")[1] # Zero preserving elt-wise unary op. if opname in {"abs", "neg", "relu", "sin"}: @@ -131,7 +134,7 @@ def sparse_export( torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 ) # TODO: Uncomment this to hack sparsity into the network. - # elif opname == "_to_dense": + # elif opname == "_to_dense" or opname == "to_dense": # # hack (assumes we never really want the to_dense for now) # node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) elif opname == "select" and node.args[0].meta.get("sparsity", None): @@ -176,8 +179,8 @@ def sparse_jit(f, *args, **kwargs): compiled = backend.compile(module) invoker = backend.load(compiled) xargs = [] - # Prepare the buffer parameters (assume all dense). - # TODO: filters out scalar arguments, anything else? + # Prepare all the named buffer parameters (assume all dense). + # All scalar arguments are filtered out since they appear inline. params = dict(f.named_buffers(remove_duplicate=True)) params_flat, params_spec = torch.utils._pytree.tree_flatten(params) for p in params_flat: @@ -339,6 +342,7 @@ def forward(self, x, v): @run # +# CHECK-LABEL: test_sparse_SpMM # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, @@ -440,7 +444,7 @@ def forward(self, x): # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) - # TODO: make this work + # TODO: make this work in MLIR # res3 = sparse_jit(net, batch_input) print("torch.sparse") print(res1) @@ -657,7 +661,14 @@ def forward(self, X): # CHECK: [0.1321, 0.2724, 0.2105, 0.3851], # CHECK: [0.2478, 0.3439, 0.1898, 0.2185], # CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}}) +# +# TODO: first row looks suspect... +# # CHECK: torch.mlir +# CHECK: {{\[}}[0. 0. 0. 0. ] +# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418] +# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ] +# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}} # def test_sparse_feature_scaling(): class Scale(nn.Module): @@ -678,11 +689,11 @@ def forward(self, F): # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(f) - # TODO: make this work - # res2 = sparse_jit(net, f) + res2 = sparse_jit(net, f) print("torch.sparse") print(res1) print("torch.mlir") + print(res2) @run From 3915db0a860daf4f3d4046a622890c2e2ee0624b Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:47:29 -0700 Subject: [PATCH 23/30] [ONNX] Add OnnxToTorch support for CenterCropPad (#3496) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 123 ++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 139 ++++++++++++++++++ 2 files changed, 262 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index a5cdc10208880..401cfb0894be7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -13,6 +13,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/Support/FormatVariadic.h" +#include using namespace mlir; using namespace mlir::torch; @@ -729,6 +730,128 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, maxExpression, minExpression, constantOne); return success(); }); + patterns.onOp( + "CenterCropPad", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, shape; + if (binder.tensorOperands(input, shape) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTy = cast(input.getType()); + SmallVector inputShape(inputTy.getSizes()); + SmallVector resultShape(resultType.getSizes()); + int64_t rank = inputShape.size(); + + SmallVector axes, defaultAxes(rank); + std::iota(defaultAxes.begin(), defaultAxes.end(), 0); + if (binder.s64IntegerArrayAttr(axes, "axes", defaultAxes)) { + return failure(); + } + int64_t axesSize = axes.size(); + + Value none = rewriter.create(binder.getLoc()); + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + auto scalarTensorType = rewriter.getType( + ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)); + + int64_t lastChangeDim = 0; + llvm::SmallVector interShape(inputShape); + for (int i = 0; i < rank; i++) { + if (inputShape[i] != resultShape[i]) { + interShape[i] = -1; + lastChangeDim = i; + } + if (interShape[i] == ShapedType::kDynamic) + interShape[i] = Torch::kUnknownSize; + } + auto interType = rewriter.getType( + interShape, resultType.getOptionalDtype()); + + Value modeVal = rewriter.create( + binder.getLoc(), rewriter.getStringAttr("floor")); + for (int i = 0; i < axesSize; i++) { + if (axes[i] < 0) + axes[i] += rank; + if (inputShape[axes[i]] == resultShape[axes[i]]) + continue; + + auto opType = axes[i] == lastChangeDim ? resultType : interType; + Value axis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])); + Value k = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value kTensor = rewriter.create( + binder.getLoc(), scalarTensorType, k); + Value sel = rewriter.create( + binder.getLoc(), scalarTensorType, shape, cstZero, kTensor); + Value outputDimSize = rewriter.create( + binder.getLoc(), rewriter.getType(), sel); + Value inputDimSize = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axes[i]))); + + if (inputShape[axes[i]] > resultShape[axes[i]]) { + Value sub = rewriter.create( + binder.getLoc(), inputDimSize, outputDimSize); + Value subTensor = rewriter.create( + binder.getLoc(), scalarTensorType, sub); + Value div = rewriter.create( + binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); + Value start = rewriter.create( + binder.getLoc(), rewriter.getType(), div); + Value end = rewriter.create( + binder.getLoc(), start, outputDimSize); + input = rewriter.create( + binder.getLoc(), opType, input, axis, start, end, cstOne); + } else { + Value sub = rewriter.create( + binder.getLoc(), outputDimSize, inputDimSize); + Value subTensor = rewriter.create( + binder.getLoc(), scalarTensorType, sub); + Value div = rewriter.create( + binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); + Value start = rewriter.create( + binder.getLoc(), rewriter.getType(), div); + Value end = rewriter.create( + binder.getLoc(), start, inputDimSize); + + SmallVector zerosShapeValues; + for (int j = 0; j < rank; j++) { + if (j == axes[i]) { + zerosShapeValues.push_back(outputDimSize); + } else { + Value dimSize = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(j))); + zerosShapeValues.push_back(dimSize); + } + } + Value zerosShapeList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + zerosShapeValues); + Value zeros = rewriter.create( + binder.getLoc(), opType, zerosShapeList, none, none, none, + none); + input = rewriter.create( + binder.getLoc(), opType, zeros, input, axis, start, end, + cstOne); + } + } + + rewriter.replaceOp(binder.op, input); + return success(); + }); patterns.onOp( "Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // https://onnx.ai/onnx/operators/onnx__Clip.html diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index cf92c04d836f0..bdc6beb0b0471 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -2447,6 +2447,8 @@ func.func @test_col2im_dilations(%arg0: !torch.vtensor<[1,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[1,1,6,6],f32> } +// ----- + // CHECK-LABEL: func.func @test_col2im_strides func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1 @@ -2483,6 +2485,141 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch // ----- +// CHECK-LABEL: @test_center_crop_pad_crop_and_pad +func.func @test_center_crop_pad_crop_and_pad(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[STR:.*]] = torch.constant.str "floor" + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C0_3:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,10,3],f32> + // CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,10,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,10,3],f32> + %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32> + return %0 : !torch.vtensor<[10,10,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_center_crop_pad_crop_axes_chw +func.func @test_center_crop_pad_crop_axes_chw(%arg0: !torch.vtensor<[3,20,8],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[STR:.*]] = torch.constant.str "floor" + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_0]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C1_1]] : !torch.vtensor<[3,20,8],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[3,20,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32> + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C1_3:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_3]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[SIZE_2]], %[[ITEM_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,10,9],f32> + // CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C2_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[3,10,9],f32>, !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,9],f32> + %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [1 : si64, 2 : si64]} : (!torch.vtensor<[3,20,8],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32> + return %0 : !torch.vtensor<[3,10,9],f32> +} + +// ----- + +// CHECK-LABEL: @test_center_crop_pad_crop_negative_axes_hwc +func.func @test_center_crop_pad_crop_negative_axes_hwc(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[STR:.*]] = torch.constant.str "floor" + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C0_3:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,9,3],f32> + // CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,9,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,9,3],f32> + %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [-3 : si64, -2 : si64]} : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32> + return %0 : !torch.vtensor<[10,9,3],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_dft_fft func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>, @@ -2506,6 +2643,8 @@ func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[10,10,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_dft_inverse_real func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>, From 73ba09c58738504869e65a5cf11e946facb61b92 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sat, 29 Jun 2024 10:43:31 +0800 Subject: [PATCH 24/30] support both option -v and TORCH_MLIR_TEST_VERBOSE (#3511) so that we could run `python3 -m e2e_testing.main -v` to specify `verbose=True` --- projects/pt1/python/torch_mlir_e2e_test/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index 56c2e91ae4c60..38b027e5d31fa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -361,7 +361,7 @@ def run_tests( try: env_verbose = os.getenv("TORCH_MLIR_TEST_VERBOSE", "0") if env_verbose is not None: - verbose = bool(int(env_verbose)) + verbose = verbose or bool(int(env_verbose)) except ValueError as e: raise ValueError( "Bad value for TORCH_MLIR_TEST_VERBOSE env var: " "Expected integer." From f9fc741eeffbb45e4bfcd40f6309bc0a61c75962 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sat, 29 Jun 2024 16:53:33 +0800 Subject: [PATCH 25/30] [Stablehlo] support aten.any.dim, aten.min.dim (#3500) * refactor `TorchToStablehlo/Reduction.cpp` * add `ConvertAtenReduceWithIndicesOp` patterns --- lib/Conversion/TorchToStablehlo/Reduction.cpp | 716 ++++++++---------- projects/pt1/e2e_testing/xfail_sets.py | 12 +- .../test_suite/reduction.py | 20 + 3 files changed, 325 insertions(+), 423 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index d8d7d43c4d24d..c9a2ad2e7ff8a 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -30,6 +30,18 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; +static SmallVector getReduceOutputShape(ArrayRef inputShape, + ArrayRef dims) { + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (size_t i = 0; i < inputShape.size(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputShape[i]); + } + } + return reduceResultShape; +} + static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); @@ -42,8 +54,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, @@ -59,8 +70,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); @@ -69,7 +79,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } - if (isa(op)) { + if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -77,8 +87,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())}); @@ -93,8 +102,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, auto constAttr = DenseElementsAttr::get(constType, one); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { APInt one(elementTy.getIntOrFloatBitWidth(), 1); auto constAttr = DenseElementsAttr::get(constType, one); return rewriter.create(op->getLoc(), constType, @@ -103,13 +111,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } if (isa(op)) { - auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)}); + auto constAttr = + DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)}); return rewriter.create(op->getLoc(), constType, constAttr); } - if (isa(op)) { - auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)}); + if (isa(op)) { + auto constAttr = + DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)}); return rewriter.create(op->getLoc(), constType, constAttr); } @@ -149,16 +159,17 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - } else if (isa(op)) { + } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - } else if (isa(op)) { + } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - } else if (isa(op)) { + } else if (isa(op)) { result = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); } else if (isa(op)) { @@ -174,11 +185,11 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, return reduce.getResults()[0]; } -// Util for converting AtenArgmaxOp and AtenMaxDimOp +// Util for converting AtenMaxDimOp/AtenMinDimOp static std::optional -getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, - ArrayRef inputShapeVec, int64_t dim, - size_t dimSizeIndexBits) { +createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op, + Value &input, ArrayRef inputShapeVec, + int64_t dim, size_t dimSizeIndexBits) { auto inputTy = cast(input.getType()); if (!inputTy) { return std::nullopt; @@ -199,8 +210,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } - std::vector outputShape(inputShape.begin(), inputShape.end()); - outputShape.erase(outputShape.begin() + dim); + auto outputShape = getReduceOutputShape(inputShape, {dim}); auto outputTy = RankedTensorType::get(outputShape, inputElemTy); auto outputIndexTy = RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); @@ -252,6 +262,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::GE); + stablehlo::ComparisonDirectionAttr compareLeDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::LE); stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::EQ); @@ -260,11 +273,21 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, - compareGeDirectionAttr, compareTypeAttr); + Value compareResult; + if (isa(op)) { + compareResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareGeDirectionAttr, compareTypeAttr); + } else if (isa(op)) { + compareResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareLeDirectionAttr, compareTypeAttr); + } else { + op->emitError("unimplement lowering of createReduceOpReturnIndices"); + return std::nullopt; + } Value retValResult = rewriter.create( - op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + op->getLoc(), compareResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. Value compareEqResult = rewriter.create( @@ -273,16 +296,35 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); Value idxWithGeVal = rewriter.create( - op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg); Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); rewriter.create( - op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + op->getLoc(), ValueRange{retValResult, retIdxResult}); } return stablehloReduceOp.getResults(); } +static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter, + Location loc, Value reduceResult, + ArrayRef inputShapeVec, + Type outType, + ArrayRef dims, + size_t dimSizeIndexBits) { + SmallVector outShapeVec(inputShapeVec); + Value one = rewriter.create( + loc, + rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); + for (auto dim : dims) { + outShapeVec[dim] = one; + } + auto outShapeTensor = + rewriter.create(loc, outShapeVec); + return rewriter.create( + loc, outType, reduceResult, outShapeTensor); +} + namespace { template class ConvertAtenReductionOp : public ConvertAtenOp { @@ -320,14 +362,6 @@ class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp { return op.emitError( "only floating-point or integer datatype legalization supported"); } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, - "IntegerType with bitwidth 8 unsupported in convertion to StableHLO"); - } - if (inputElemTy != outTy.getElementType()) { // use output type as computation type input = rewriter.create(op->getLoc(), input, @@ -347,7 +381,7 @@ class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp { }; template -class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp { +class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp { public: using ConvertAtenReductionOp::ConvertAtenReductionOp; using OpAdaptor = typename AtenOpT::Adaptor; @@ -356,7 +390,10 @@ class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp { ConversionPatternRewriter &rewriter) const override { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { + auto outTy = dyn_cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getType())); + if (!inputTy || !outTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } @@ -366,12 +403,78 @@ class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp { return op.emitError( "only floating-point or integer datatype legalization supported"); } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { + if (inputElemTy != outTy.getElementType()) { + // use output type as computation type + input = rewriter.create(op->getLoc(), input, + outTy.getElementType()); + } + + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure( - op, - "IntegerType with bitwidth 8 unsupported in convertion to StableHLO"); + op, "non-const integer `dim` is not supported"); + } + dim = toPositiveDim(dim, inputTy.getRank()); + SmallVector reduceResultShape = + getReduceOutputShape(inputTy.getShape(), {dim}); + + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, + RankedTensorType::get(reduceResultShape, outTy.getElementType()), {dim}, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } + + if (keepDim) { + const auto &options = ConvertAtenReductionOp::getOptions(); + auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + reduceResult = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim}, + options.dimSizeIndexBits); + } + rewriter.replaceOp(op, reduceResult); + return success(); + } +}; + +template +class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp { +public: + using ConvertAtenReductionOp::ConvertAtenReductionOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = dyn_cast(input.getType()); + auto outTy = dyn_cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getType())); + if (!inputTy || !outTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + if (inputElemTy != outTy.getElementType()) { + // use output type as computation type + input = rewriter.create(op->getLoc(), input, + outTy.getElementType()); } bool keepDim = false; @@ -393,19 +496,16 @@ class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp { } } llvm::sort(dims.begin(), dims.end()); - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputTy.getDimSize(i)); - } - } + SmallVector reduceResultShape = + getReduceOutputShape(inputTy.getShape(), dims); Value reduceResult = createReduceOpWithSingleRegionOp( - op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims, + op, input, + RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims, rewriter); - if (!reduceResult) + if (!reduceResult) { return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } if (keepDim) { const auto &options = ConvertAtenReductionOp::getOptions(); @@ -415,215 +515,104 @@ class ConvertAtenReduceKeepDimOp : public ConvertAtenReductionOp { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, - ConvertAtenReductionOp::getTypeConverter()->convertType( - op.getType()), - reduceResult, outShapeTensor); - return success(); + reduceResult = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims, + options.dimSizeIndexBits); } rewriter.replaceOp(op, reduceResult); return success(); } }; -} // namespace - -// AtenArgmaxOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenArgmaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported! - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenArgmaxOp to StableHLO"); - } - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); - } - dim = toPositiveDim(dim, inputTy.getRank()); - if (!isValidDim(dim, inputTy.getRank())) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); - } - - bool keepDim = false; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { - return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); - } - - const auto &options = getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(inputShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - } - auto inputShapeVec = *inputShapeInfo; - auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, - dim, options.dimSizeIndexBits) - .value(); - - if (keepDim) { - auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), stablehloReduceResults[1], - outShapeTensor); - return success(); - } - - rewriter.replaceOp(op, stablehloReduceResults[1]); - return success(); -} -} // namespace - -// AtenMaxDimOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenMaxDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxDimOp to StableHLO"); - } +template +class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp { +public: + using ConvertAtenReductionOp::ConvertAtenReductionOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } - RankedTensorType valResultType = cast( - getTypeConverter()->convertType(op.getResult(0).getType())); - RankedTensorType idxResultType = cast( - getTypeConverter()->convertType(op.getResult(1).getType())); - Type idxElementType = idxResultType.getElementType(); - if (!isa(idxElementType)) { - return op.emitError("Aten.max.dim needs integer-like result"); - } + RankedTensorType valResultType = cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getResult(0).getType())); + RankedTensorType idxResultType = cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getResult(1).getType())); + Type idxElementType = idxResultType.getElementType(); + if (!isa(idxElementType)) { + return op.emitError("indices result should to be integer tyep"); + } - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); - } - dim = toPositiveDim(dim, inputTy.getRank()); - if (!isValidDim(dim, inputTy.getRank())) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); - } - bool keepDim = false; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { - return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); - } + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); + } + dim = toPositiveDim(dim, inputTy.getRank()); + if (!isValidDim(dim, inputTy.getRank())) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } - const auto &options = getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(inputShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - } - auto inputShapeVec = *inputShapeInfo; + const auto &options = ConvertAtenReductionOp::getOptions(); + auto inputShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto inputShapeVec = *inputShapeInfo; - if (op.getResult(1).use_empty()) { - llvm::SmallVector outputShape(inputTy.getShape()); - outputShape.erase(outputShape.begin() + dim); - Value reduceResult = createReduceOpWithSingleRegionOp( - op, input, RankedTensorType::get(outputShape, inputElemTy), - ArrayRef{dim}, rewriter); - if (!reduceResult) - return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + if (op.getResult(1).use_empty()) { + llvm::SmallVector outputShape(inputTy.getShape()); + outputShape.erase(outputShape.begin() + dim); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, RankedTensorType::get(outputShape, inputElemTy), + ArrayRef{dim}, rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } - if (keepDim) { - auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - - auto stablehloReduceValueResult = - rewriter.create( - op->getLoc(), valResultType, reduceResult, outShapeTensor); - rewriter.replaceOp(op, {stablehloReduceValueResult, Value()}); + if (keepDim) { + reduceResult = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType, + {dim}, options.dimSizeIndexBits); + } + rewriter.replaceOp(op, {reduceResult, Value()}); return success(); - } - rewriter.replaceOp(op, {reduceResult, Value()}); - return success(); - } else { - auto stablehloReduceResults = - getMaxInDim(rewriter, op, input, inputShapeVec, dim, - options.dimSizeIndexBits) - .value(); - - if (keepDim) { - auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - - auto stablehloReduceValueResult = - rewriter.create( - op->getLoc(), valResultType, stablehloReduceResults[0], - outShapeTensor); - auto stablehloReduceIndexResult = - rewriter.create( - op->getLoc(), idxResultType, stablehloReduceResults[1], - outShapeTensor); + } else { + ValueRange stablehloReduceResults = + createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); + if (keepDim) { + stablehloReduceResults[0] = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), stablehloReduceResults[0], inputShapeVec, + valResultType, {dim}, options.dimSizeIndexBits); + stablehloReduceResults[1] = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), stablehloReduceResults[1], inputShapeVec, + idxResultType, {dim}, options.dimSizeIndexBits); + } rewriter.replaceOp( - op, {stablehloReduceValueResult, stablehloReduceIndexResult}); + op, {stablehloReduceResults[0], stablehloReduceResults[1]}); return success(); } - rewriter.replaceOp(op, - {stablehloReduceResults[0], stablehloReduceResults[1]}); - return success(); - } -} + }; +}; } // namespace // AtenSumDimIntListOp @@ -653,17 +642,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "Only floating-point or integer datatype legalization supported"); } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumDimIntListOp to StableHLO"); - } - SmallVector inputDims; SmallVector dims; - if (failed(checkNotNone(rewriter, op, op.getDim()))) { inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); } else { @@ -675,7 +655,6 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); } } - for (auto d : inputDims) { d = toPositiveDim(d, inputTy.getRank()); // Drop invalid dims @@ -683,46 +662,22 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( dims.push_back(d); } } + llvm::sort(dims.begin(), dims.end()); - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputTy.getDimSize(i)); - } - } + SmallVector reduceResultShape = + getReduceOutputShape(inputTy.getShape(), dims); bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), - RankedTensorType::get(reduceResultShape, outTy.getElementType()), input, - initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Region ®ion = stablehloReduceOp.getBody(); - Block &block = region.emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, + RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } if (keepDim) { @@ -733,23 +688,11 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResult(0), outShapeTensor); - return success(); + reduceResult = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims, + options.dimSizeIndexBits); } - rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); + rewriter.replaceOp(op, reduceResult); return success(); } } // namespace @@ -789,18 +732,12 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "invalid dimension detected in `dim`"); } } - // Sort the dims in ascending order, making the conversion // stable with unordered dims. std::sort(dims.begin(), dims.end()); - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputRank; i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputType.getDimSize(i)); - } - } + SmallVector reduceResultShape = + getReduceOutputShape(inputType.getShape(), dims); bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { @@ -810,36 +747,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto squareOp = rewriter.create(op->getLoc(), input, input); - auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); - if (!initValue) { - return failure(); - } - - auto reduceOp = rewriter.create( - op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType), - squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Region ®ion = reduceOp.getBody(); - Block &block = region.emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputElemType); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto firstArgument = *block.args_begin(); - auto secondArgument = *block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - - auto addResult = rewriter.create( - op->getLoc(), firstArgument, secondArgument); - rewriter.create(op->getLoc(), addResult.getResult()); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, squareOp.getResult(), + RankedTensorType::get(reduceResultShape, inputElemType), dims, rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } - auto output = - rewriter.create(op->getLoc(), reduceOp.getResult(0)); + Value output = rewriter.create(op->getLoc(), reduceResult); if (keepDim) { auto outShapeInfo = @@ -848,22 +763,12 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), output, - outShapeTensor); - return success(); + output = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), output, *outShapeInfo, + getTypeConverter()->convertType(op.getType()), dims, + options.dimSizeIndexBits); } - rewriter.replaceOp(op, output.getResult()); + rewriter.replaceOp(op, output); return success(); } } // namespace @@ -920,13 +825,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( std::sort(dims.begin(), dims.end()); } - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputType.getRank(); i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputType.getDimSize(i)); - } - } + SmallVector reduceResultShape = + getReduceOutputShape(inputType.getShape(), dims); bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { @@ -934,46 +834,27 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op, "non-const bool `keepdim` is not supported"); } - auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter); - if (!initValue) { - return failure(); - } - Value absValue = rewriter.create(op->getLoc(), input); Value powValue = rewriter.create(op->getLoc(), absValue, ord, nullptr); - auto reduceOp = rewriter.create( - op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType), - powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Region ®ion = reduceOp.getBody(); - Block &block = region.emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, outElemType); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto firstArgument = *block.args_begin(); - auto secondArgument = *block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - - auto addResult = rewriter.create( - op->getLoc(), firstArgument, secondArgument); - rewriter.create(op->getLoc(), addResult.getResult()); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } + + auto scalarType = RankedTensorType::get({}, outElemType); auto constantOne = rewriter.create( - op->getLoc(), blockArgumentTy, + op->getLoc(), scalarType, DenseElementsAttr::get( - blockArgumentTy, + scalarType, APFloat(cast(outElemType).getFloatSemantics(), 1))); auto reciprocalOrd = rewriter.create( - op->getLoc(), blockArgumentTy, constantOne, ord); - auto output = rewriter.create( - op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr); + op->getLoc(), scalarType, constantOne, ord); + Value output = rewriter.create( + op->getLoc(), reduceResult, reciprocalOrd, nullptr); if (keepDim) { auto outShapeInfo = @@ -982,23 +863,11 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), output, - outShapeTensor); - return success(); + output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output, + *outShapeInfo, outType, dims, + options.dimSizeIndexBits); } - - rewriter.replaceOp(op, output.getResult()); + rewriter.replaceOp(op, output); return success(); } } // namespace @@ -1010,9 +879,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) - - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp); @@ -1022,7 +888,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( target.addIllegalOp(); \ patterns.add>(typeConverter, context, \ options) - INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp); INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp); @@ -1031,12 +896,25 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp); #undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN -#define INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenOp) \ +#define INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context, \ - options) + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp); +#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN + +#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAmaxOp); + INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAminOp); +#undef INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN - INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAmaxOp); - INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN(AtenAminOp); -#undef INSERT_ATEN_REDUCTION_KEEP_DIM_OP_PATTERN +#define INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMaxDimOp); + INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMinDimOp); +#undef INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bc99fde51b785..6ac3ae099a701 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -32,6 +32,7 @@ # unimplemented lowering torch -> linalg for torchvision.deform_conv2d # this is added to check the torch.onnx.export -> import_onnx -> torch path "DeformConv2D_basic", + "ReduceAnyDimFloatModule_basic", } LINALG_CRASHING_SET = { @@ -340,6 +341,7 @@ } FX_IMPORTER_XFAIL_SET = { + "ReduceAnyDimFloatModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", @@ -502,7 +504,6 @@ "ArgminIntModule_multiple_mins", "ArgminModule_basic", "ArgminModule_keepDim", - "ArgminModule_with_dim", "AtenComplexImagModule_basic", "AtenComplexRealModule_basic", "AtenComplexViewModule_basic", @@ -716,10 +717,7 @@ "ReduceAllDimFloat_basic", "ReduceAllDimInt_basic", "ReduceMaxAlongDimUnsignedInt_basic", - "ReduceMinAlongDimNegative_basic", - "ReduceMinAlongDimSignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "ReduceMinAlongDim_basic", "ReduceMinKeepDimReturnBoth_basic", "ReduceMinKeepDim_basic", "ReduceProdDimIntFloatModule_basic", @@ -832,6 +830,11 @@ } STABLEHLO_PASS_SET = { + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDim_basic", + "ArgminModule_with_dim", + "ReduceMinAlongDimSignedInt_basic", + "ReduceAnyDimFloatModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", @@ -2198,6 +2201,7 @@ # Failure - cast error "PermuteNegativeIndexModule_basic", # Failure - incorrect numerics + "ReduceAnyDimFloatModule_basic", "AvgPool2dDivisorOverrideModule_basic", "BroadcastDynamicDimModule_basic", "ElementwiseAtan2TensorIntModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 4891d6eaa1f04..347a1f8cc2571 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -239,6 +239,26 @@ def ReduceAnyFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) +class ReduceAnyDimFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.any(a, dim=0) + + +@register_test_case(module_factory=lambda: ReduceAnyDimFloatModule()) +def ReduceAnyDimFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + + # ============================================================================== From 0e71a192d82fdfcfe5d3eb90882d9f07eca077ae Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sat, 29 Jun 2024 21:44:05 +0800 Subject: [PATCH 26/30] [Torch] support decomposition of aten.aminmax (#3513) * unify decompisition of `aten.amax` and `aten.amin` * support `aten.amax` with `dim=()` --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++ lib/Conversion/TorchToStablehlo/Reduction.cpp | 16 +- .../Transforms/AbstractInterpLibrary.cpp | 21 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 208 +++++++++--------- .../Transforms/LowerToBackendContract.cpp | 4 +- projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 12 + .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 69 ++++++ 9 files changed, 254 insertions(+), 106 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f4223b1f4bf72..9428e749b5f96 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11463,6 +11463,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [ }]; } +def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchOptionalTensorType:$min, + AnyTorchOptionalTensorType:$max + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenAminmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index c9a2ad2e7ff8a..bc77a860adea0 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -488,14 +488,18 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } - for (auto d : inputDims) { - d = toPositiveDim(d, inputTy.getRank()); - // Drop invalid dims - if (isValidDim(d, inputTy.getRank())) { - dims.push_back(d); + if (inputDims.size() == 0) { + dims = llvm::to_vector(llvm::seq(0, inputTy.getRank())); + } else { + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } } + llvm::sort(dims.begin(), dims.end()); } - llvm::sort(dims.begin(), dims.end()); SmallVector reduceResultShape = getReduceOutputShape(inputTy.getShape(), dims); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index f8a5409b8a707..8bf50fd21cc22 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7371,6 +7371,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.tuple, list>) {\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %4 : !torch.tuple, list>\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %4 : !torch.tuple, list>\n" +" }\n" +" return %1 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.optional to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" @@ -13568,6 +13584,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2086fb68afa25..36e79736381e4 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -113,6 +113,25 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, .getValues(); } +// Reduction function to calculate min along given `dim`. +static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc, + Operation *op, Value input, Value dim, + bool keepDim) { + Value keepDimCst = rewriter.create(loc, keepDim); + BaseTensorType valueType = cast(computeReductionType( + rewriter, op, cast(input.getType()), dim, keepDim)); + if (!valueType) + return nullptr; + BaseTensorType indexType = + cast(valueType.getWithSizesAndDtype( + !valueType.hasSizes() ? std::optional>() + : llvm::ArrayRef(valueType.getSizes()), + IntegerType::get(op->getContext(), 64, IntegerType::Signed))); + return rewriter + .create(loc, valueType, indexType, input, dim, keepDimCst) + .getValues(); +} + // Helper for creating `aten::sub_tensor_op`. static Value createTensorSub(PatternRewriter &rewriter, Location loc, Type tensorType, Value lhs, Value rhs) { @@ -605,65 +624,6 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter, return out; } -namespace { -/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the -/// number of dimensions across which the max needs to be computed. -/// Eg: -/// INPUT: -/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) -/// -/// OUTPUT: -/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 -/// input_2 = aten.max.dim(input_1, 1, keepdim) #2 -/// final_output = aten.max.dim(input_2, 0, keepdim) #3 -/// -/// NOTE: We iterate over, in reverse order, every dimension included in `dim` -/// of the `aten.amax` op and create an `aten.amax.dim` op. -/// Input tensor to the next `aten.amax.dim` op is thus the output of the -/// previous `aten.amax.dim` op. -class DecomposeAtenAmaxOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenAmaxOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) - - return rewriter.notifyMatchFailure(op, - "non-const dim parameter unsupported"); - - bool keepDim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) - return rewriter.notifyMatchFailure( - op, "Expected a constant boolean value for keepDim"); - - Value input = op.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy || !inputTy.hasSizes()) { - return rewriter.notifyMatchFailure(op, - "Expected input type having sizes"); - } - // For every dimension included in `dim` of the op, iterated over in - // reverse order, we create a call to aten.max.dim. - std::sort(dims.rbegin(), dims.rend()); - for (int64_t dimInt : dims) { - int64_t inputRank = inputTy.getSizes().size(); - dimInt = toPositiveDim(dimInt, inputRank); - if (!isValidDim(dimInt, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimInt)); - // The input to the next invocation of aten.max.dim is the output of the - // previous aten.max.dim op. - input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); - } - rewriter.replaceOp(op, input); - return success(); - } -}; -} // end namespace - namespace { class DecomposeAtenTriuOp : public OpRewritePattern { public: @@ -1880,52 +1840,69 @@ class DecomposeAten_LogSoftmaxBackwardDataOp } // namespace namespace { -class DecomposeAtenAMinMaxOp : public OpRewritePattern { +/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the +/// number of dimensions across which the max needs to be computed. +/// Eg: +/// INPUT: +/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) +/// +/// OUTPUT: +/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 +/// input_2 = aten.max.dim(input_1, 1, keepdim) #2 +/// final_output = aten.max.dim(input_2, 0, keepdim) #3 +/// +/// NOTE: We iterate over, in reverse order, every dimension included in `dim` +/// of the `aten.amax` op and create an `aten.amax.dim` op. +/// Input tensor to the next `aten.amax.dim` op is thus the output of the +/// previous `aten.amax.dim` op. +template +class DecomposeAtenAminAmaxOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(Torch::AtenAminOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - llvm::SmallVector dimList; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { - return rewriter.notifyMatchFailure(op, "dims not foldable constants"); + Location loc = op.getLoc(); + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) + return rewriter.notifyMatchFailure( + op, "Expected a constant boolean value for keepDim"); + + Value input = op.getSelf(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "Expected input type having sizes"); } - bool keepdim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { - return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); + SmallVector dims; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure(op, + "non-const dim parameter unsupported"); + if (dims.size() == 0) { + dims = llvm::to_vector(llvm::seq(0, inputTy.getSizes().size())); } - auto loc = op.getLoc(); - std::sort(dimList.begin(), dimList.end(), std::greater()); - - Value reduction = op.getSelf(); - auto resultTy = cast(op.getType()); - auto reductionTy = cast(reduction.getType()); - llvm::SmallVector reductionShape(reductionTy.getSizes()); - - for (auto dim : dimList) { - auto dimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim)); - reductionShape[dim] = 1; - if (!keepdim) { - for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) - reductionShape[i] = reductionShape[i + 1]; - reductionShape.resize(reductionShape.size() - 1); + // For every dimension included in `dim` of the op, iterated over in + // reverse order, we create a call to aten.max.dim. + std::sort(dims.rbegin(), dims.rend()); + for (int64_t dimInt : dims) { + int64_t inputRank = inputTy.getSizes().size(); + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimInt)); + // The input to the next invocation of aten.max.dim is the output of the + // previous aten.max.dim op. + static_assert(std::is_same_v || + std::is_same_v); + if (std::is_same_v) { + input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); + } else if (std::is_same_v) { + input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim); } - - reductionTy = rewriter.getType( - reductionShape, resultTy.getOptionalDtype()); - auto idxTy = rewriter.getType( - reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); - llvm::SmallVector types{reductionTy, idxTy}; - - reduction = rewriter - .create(loc, types, reduction, - dimValue, op.getKeepdim()) - .getResult(0); } - - rewriter.replaceOp(op, reduction); + rewriter.replaceOp(op, input); return success(); } }; @@ -1987,6 +1964,36 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { }; } // namespace +// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp` +namespace { +class DecomposeAtenAminmaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenAminmaxOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Torch::ListType listType = + rewriter.getType(rewriter.getType()); + Value dimList; + if (isa(op.getDim().getType())) { + dimList = rewriter.create(loc, listType, + ArrayRef{}); + } else { + dimList = rewriter.create( + loc, listType, ArrayRef{op.getDim()}); + } + + auto amin = rewriter.create( + loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim()); + auto amax = rewriter.create( + loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim()); + rewriter.replaceOp(op, {amin, amax}); + return success(); + } +}; +} // namespace + // Decompose `aten.bucketize` into the following op sequence: // // def aten_bucketize(input, boundaries, out_int32, right): @@ -8598,7 +8605,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -8631,10 +8637,15 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAminAmaxOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAminAmaxOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -8707,7 +8718,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 15bebfc64390e..5e83c585ae8e4 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -438,6 +438,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -502,7 +505,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6ac3ae099a701..8272bc4b06918 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -830,6 +830,9 @@ } STABLEHLO_PASS_SET = { + "ReduceAminmaxSingleDim_basic", + "ReduceAminmaxAllDims_basic", + "ReduceAmaxEmptyDim_basic", "ReduceMinAlongDimNegative_basic", "ReduceMinAlongDim_basic", "ArgminModule_with_dim", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 9052c8cc2057a..6e4957e58898d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -722,6 +722,13 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: + if dim is None: + return [], [] + else: + reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) + return reduced_shape, reduced_shape + def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) @@ -4524,6 +4531,11 @@ def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: return aten〇min〡dtype(self_rank_dtype), torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇aminmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index c3cb95dd7fbeb..8e6745ea4a575 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -841,6 +841,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amin : (Tensor, int[], bool) -> (Tensor)") + emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)") emit( "aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 347a1f8cc2571..7cf6dd694458e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1204,6 +1204,29 @@ def ReduceAmaxMultiDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAmaxEmptyDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.amax(a, dim=()) + + +@register_test_case(module_factory=lambda: ReduceAmaxEmptyDim()) +def ReduceAmaxEmptyDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + class ReduceAmaxOutOfOrderDim(torch.nn.Module): def __init__(self): super().__init__() @@ -1273,6 +1296,52 @@ def ReduceAminSingleDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAminmaxSingleDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.aminmax(a, dim=1) + + +@register_test_case(module_factory=lambda: ReduceAminmaxSingleDim()) +def ReduceAminmaxSingleDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + +class ReduceAminmaxAllDims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.aminmax(a) + + +@register_test_case(module_factory=lambda: ReduceAminmaxAllDims()) +def ReduceAminmaxAllDims_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + class ReduceMinFloatModule(torch.nn.Module): def __init__(self): super().__init__() From 2f231f394e39458df7eaa55c5af1d1929a6acd77 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 1 Jul 2024 22:15:45 +0530 Subject: [PATCH 27/30] Bump Onnx Version to 1.16.1 (#3515) This commit adds the support for new data types: uint4, and int4 and uint8 tensor protos. Also, it moves some tests from failing to crashing. Fixes https://github.com/llvm/torch-mlir/issues/3507 Signed-Off By: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 8 ++++---- python/torch_mlir/extras/onnx_importer.py | 5 +++++ python/torch_mlir/tools/import_onnx/__main__.py | 2 +- test-requirements.txt | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8272bc4b06918..adfb68b94be3a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2572,8 +2572,6 @@ "SplitDimStaticModule_basic", "SqrtIntConstantModule_basic", "SqrtIntModule_basic", - "StdCorrectionEmptyDimModule_basic", - "StdDimEmptyDimModule_basic", "SubFloatModule_basic", "SubIntModule_basic", "TanhBackward_basic", @@ -2627,8 +2625,6 @@ "UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2d_basic", - "VarCorrectionEmptyDimModule_basic", - "VarDimEmptyDimModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseModule_basic", "ViewDynamicExpandCollapseModule_basic", @@ -2797,6 +2793,10 @@ # Runtime crash: mismatched size for broadcast "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "StdDimEmptyDimModule_basic", + "StdCorrectionEmptyDimModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", } FX_IMPORTER_TOSA_XFAIL_SET = { diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index f8b10a2a46469..9fe29212386a7 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -1098,6 +1098,8 @@ def get_operator_function( onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), onnx.TensorProto.DataType.STRING: lambda: "!torch.str", + onnx.TensorProto.DataType.UINT4: lambda: IntegerType.get_unsigned(4), + onnx.TensorProto.DataType.INT4: lambda: IntegerType.get_signed(4), # Ommitted: STRING, } @@ -1134,6 +1136,9 @@ def get_operator_function( ), signless=False, ), + onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False + ), onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get( np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False ), diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index d20c212d0ede1..fa0e2a89dbba2 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -84,7 +84,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file) else: raw_model = onnx.load(args.input_file, load_external_data=False) - onnx.load_external_data_for_model(raw_model, args.data_dir) + onnx.load_external_data_for_model(raw_model, str(args.data_dir)) if args.opset_version: raw_model = onnx.version_converter.convert_version( diff --git a/test-requirements.txt b/test-requirements.txt index b21e8dfcd0213..42278b3cbcf6c 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,5 @@ pillow dill multiprocess -onnx==1.15.0 +onnx==1.16.1 mpmath==1.3.0 From e2fbded49cdfa37185e8dbfbef0164e23d005c08 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 2 Jul 2024 09:08:57 +0800 Subject: [PATCH 28/30] =?UTF-8?q?[Torch=20Dialect]=20improve=20argmax/argm?= =?UTF-8?q?in's=20decomposition=20to=20support=20keep=E2=80=A6=20(#3514)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …dim=True when dim=None --- .../Transforms/AbstractInterpLibrary.cpp | 46 +++++++++----- .../Torch/Transforms/DecomposeComplexOps.cpp | 60 ++++++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 32 +++++++--- .../test_suite/reduction.py | 23 +++++++ 5 files changed, 126 insertions(+), 36 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 8bf50fd21cc22..0e244e51a96d7 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7313,11 +7313,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @__torch__.patched_argmax_shape_func(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %arg2 : !torch.bool\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %5 = torch.aten.append.t %3, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" @@ -7372,19 +7399,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple, list> {\n" -" %none = torch.constant.none\n" -" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.tuple, list>) {\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" %3 = torch.prim.ListConstruct : () -> !torch.list\n" -" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %4 : !torch.tuple, list>\n" -" } else {\n" -" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" -" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %4 : !torch.tuple, list>\n" -" }\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" " return %1 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 36e79736381e4..f966b320c1323 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1920,15 +1920,19 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value dim = op.getDim(); - Value keepDim = op.getKeepdim(); Value result = op.getResult(); + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure( + op, "expected keepdim to be a constant bool"); + } BaseTensorType inputType = cast(input.getType()); BaseTensorType indicesTensorType = cast(result.getType()); std::optional maybeInputRank = getTensorRank(input); - if (!maybeInputRank) { + if (!maybeInputRank || *maybeInputRank == 0) { return rewriter.notifyMatchFailure( - op, "expected input tensor to have a rank"); + op, "expected input tensor to have a rank > 0"); } unsigned inputRank = *maybeInputRank; if (!indicesTensorType.hasSizes()) @@ -1945,21 +1949,49 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { BaseTensorType flattenType = cast(inputType.getWithSizesAndDtype( {kUnknownSize}, inputType.getOptionalDtype())); - dim = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(inputRank - 1)); + Value falseValue = rewriter.create(loc, false); input = rewriter.create(loc, flattenType, input, - dim, end); + zero, end); + Value resultIndices = + rewriter + .create( + loc, + valueTensorType.getWithSizesAndDtype( + ArrayRef{}, valueTensorType.getOptionalDtype()), + indicesTensorType.getWithSizesAndDtype( + ArrayRef{}, + indicesTensorType.getOptionalDtype()), + input, /*dim=*/zero, /*keepdim=*/falseValue) + .getIndices(); + if (keepDim) { + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value dimList = rewriter.create( + loc, + Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + SmallVector(inputRank, one)); + resultIndices = rewriter.create( + loc, + indicesTensorType.getWithSizesAndDtype( + SmallVector(inputRank, 1), + indicesTensorType.getOptionalDtype()), + resultIndices, dimList); + } + rewriter.replaceOp(op, resultIndices); + return success(); + } else { + Value resultIndices = + rewriter + .create(loc, valueTensorType, indicesTensorType, + input, dim, op.getKeepdim()) + .getIndices(); + rewriter.replaceOp(op, resultIndices); + return success(); } - - Value resultArg = - rewriter - .create(loc, valueTensorType, indicesTensorType, input, - dim, keepDim) - .getIndices(); - - rewriter.replaceOp(op, resultArg); - return success(); } }; } // namespace diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index adfb68b94be3a..7bbd82a0d7c91 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1505,6 +1505,7 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ArgmaxKeepdimModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6e4957e58898d..1dbadd6897b55 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -680,8 +680,19 @@ def aten〇trace〡shape(self: List[int]) -> List[int]: assert len(self) == 2, "input must have rank 2" return [] +# TODO: replace this patched function with `upstream_shape_functions.argmax` when upstream fix it +# see https://github.com/pytorch/pytorch/pull/129838 +def patched_argmax_shape_func(self: List[int], dim: Optional[int] = None, keepdim: bool = False): + if dim is None and keepdim: + out: List[int] = [] + for i in self: + out.append(1) + return out + return upstream_shape_functions.argmax(self, dim, keepdim) + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`. Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`. @@ -690,11 +701,11 @@ def aten〇trace〡shape(self: List[int]) -> List[int]: ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. ]) def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: - return upstream_shape_functions.argmax(self, dim, keepdim) + return patched_argmax_shape_func(self, dim, keepdim) def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: # There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here. - return upstream_shape_functions.argmax(self, dim, keepdim) + return patched_argmax_shape_func(self, dim, keepdim) # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor, # making it impossible to add support for it using the current design of the shape library. @@ -722,12 +733,19 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`. + Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. + Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`. + Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`. + Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`. + ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds. + ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. +]) def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: - if dim is None: - return [], [] - else: - reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) - return reduced_shape, reduced_shape + reduced_shape = patched_argmax_shape_func(self, dim, keepdim) + return reduced_shape, reduced_shape def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 7cf6dd694458e..9a683e3c6219b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1533,6 +1533,29 @@ def ArgmaxModule_basic(module, tu: TestUtils): # ============================================================================== +class ArgmaxKeepdimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.argmax(a, keepdim=True) + + +@register_test_case(module_factory=lambda: ArgmaxKeepdimModule()) +def ArgmaxKeepdimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ArgmaxIntModule(torch.nn.Module): def __init__(self): super().__init__() From f1e3701cafe827e242cddc11124b9b222c716e3c Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 2 Jul 2024 15:31:06 +0800 Subject: [PATCH 29/30] [Stablehlo] fix compareOp with scalar's lowering (#3518) * use lhs tensor's element type as compute type when rhs is scalar. * previously `a != 1.0`(a is a fp32 tensor) will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf64>, tensor<2x5xf64>) -> tensor<2x5xi1>` * now it will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xi1>` --- lib/Conversion/TorchToStablehlo/Basic.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 4d75979027cf5..644d28cc0974b 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -517,6 +517,8 @@ class ConvertAtenCompareOp : public OpConversionPattern { if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); + // use lhs's element type as compute type + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); rhsTy = dyn_cast(rhs.getType()); } From ca0e9066755b35c0889c6ab792265b0886325f50 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 2 Jul 2024 09:06:20 -0700 Subject: [PATCH 30/30] Fix `uint64_t` type. (#3519) `u_int64_t` is nonstandard and does not exist in MSVC. --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f966b320c1323..24a79cb0d3125 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2456,8 +2456,8 @@ class DecomposeAtenRenormOp : public OpRewritePattern { // Arragne reduce_dims tensor (vector), [0, 1, ... , dim-1, dim+1, ... , // ndim-1] llvm::SmallVector reduceDimsVector; - for (u_int64_t i = 0; i < ndim; i++) { - if (i == (u_int64_t)dimInt) + for (uint64_t i = 0; i < ndim; i++) { + if (i == (uint64_t)dimInt) continue; Value constI = rewriter.create( @@ -2473,8 +2473,8 @@ class DecomposeAtenRenormOp : public OpRewritePattern { // Make output shape for linalg.vector_norm operation SmallVector inputSizeValue; - for (u_int64_t i = 0; i < inputSize.size(); i++) { - if (i != (u_int64_t)dimInt) + for (uint64_t i = 0; i < inputSize.size(); i++) { + if (i != (uint64_t)dimInt) inputSize[i] = 1; inputSizeValue.push_back(