From 4f5f577d571dfca130fb409e8957522eb302dd80 Mon Sep 17 00:00:00 2001
From: Yan Xu
Date: Mon, 15 Aug 2022 09:28:47 +0800
Subject: [PATCH] add native_dropout and related ops pattern (#1211)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td    |  3 +-
 lib/Conversion/TorchToMhlo/Basic.cpp         |  5 +-
 .../Torch/Transforms/DecomposeComplexOps.cpp | 44 +++++++++++++++++
 .../TorchConversion/Transforms/Passes.cpp    |  2 +
 test/Conversion/TorchToMhlo/dropout.mlir     | 47 +++++++++++++++++++
 test/Conversion/TorchToMhlo/view_like.mlir   | 12 ++++-
 6 files changed, 109 insertions(+), 4 deletions(-)
 create mode 100644 test/Conversion/TorchToMhlo/dropout.mlir

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ba0062286456..e36b40e0f55c 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5516,9 +5516,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
 }
 
 def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
+    NoSideEffect,
     AllowsTypeRefinement,
     HasValueSemantics,
-    ReadOnly
+    ReadOnly,
   ]> {
   let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
   let arguments = (ins
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 07f041b20e47..362c8c1be946 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -312,6 +312,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
         std::is_same<AtenOpT, AtenGtScalarOp>()) {
       compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
           op->getContext(), chlo::ComparisonDirection::GT);
+    } else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::GE);
     } else if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
                std::is_same<AtenOpT, AtenLtScalarOp>()) {
       compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
@@ -1023,7 +1026,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 } // namespace
-
 void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1083,6 +1085,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
+  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6da22ae6cdec..a77a1b9cfa60 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1332,6 +1332,48 @@ class DecomposeAtenNativeDropoutBackwardOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenNativeDropoutOp : public OpRewritePattern<AtenNativeDropoutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    Value prob = op.p();
+    bool train = false;
+    if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
+      return rewriter.notifyMatchFailure(op, "train must be a boolean constant");
+
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!train) {
+      // TODO(yancey.yx): supports inference mode
+      return op.emitError(
+          "native_dropout does not support argument train is false");
+    }
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(
+          op, "only support floating type input for training mode");
+    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
+    Value floatOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
+    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
+    Value maskedInput =
+        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
+    Value output =
+        rewriter.create<AtenDivScalarOp>(loc, inputType, maskedInput, oneMinusP);
+    Value one =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    boolMask = rewriter.create<AtenGeScalarOp>(
+        loc, op.getResult(1).getType(), boolMask, one);
+    rewriter.replaceOp(op, {output, boolMask});
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.var into: aten.var.dim op.
 namespace {
 class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
@@ -2977,6 +3019,8 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    patterns.add<DecomposeAtenNativeDropoutOp>(context);
+    target.addIllegalOp<AtenNativeDropoutOp>();
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index b2f012f03b2f..6e1d68a76836 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -123,6 +123,8 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
 void TorchConversion::createTorchBackendToMhloBackendPipeline(
     OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
   pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
 
   // Clean up any non-canonical code introduced above..
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/test/Conversion/TorchToMhlo/dropout.mlir b/test/Conversion/TorchToMhlo/dropout.mlir
new file mode 100644
index 000000000000..b61a61b3bf83
--- /dev/null
+++ b/test/Conversion/TorchToMhlo/dropout.mlir
@@ -0,0 +1,47 @@
+// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.native_dropout.train(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: f64) -> (tensor<?x?xf32>, tensor<?x?xi1>) {
+// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[CST_0:.*]] = arith.constant 1 : index
+// CHECK: %[[CST_1:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
+// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
+// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor<?x?xf64>
+// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64>
+// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<?x?xf64>
+// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f64>, tensor<2xindex>) -> tensor<?x?xf64>
+// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xi1>
+// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor<?x?xi1>) -> tensor<?x?xf32>
+// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex>
+// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor<?x?xf32>) {
+// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor<?x?xf32>
+// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
"mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor, tensor) -> tensor +// CHECK: return %[[T24]], %[[T26]] : tensor, tensor +func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) { + %bool_true = torch.constant.bool true + %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> + return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> +} \ No newline at end of file diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index 1c878a5a555d..41d84c76208e 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -354,9 +354,17 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32> // CHECK: %[[INTneg1:.*]] = torch.constant.int -1 // CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]] // CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32> +// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]] +// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32> +// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64 +// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]] // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]] // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]