diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4057e9972ea4..647502f15663 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5951,9 +5951,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
 }
 
 def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
+    NoSideEffect,
     AllowsTypeRefinement,
     HasValueSemantics,
-    ReadOnly
+    ReadOnly,
   ]> {
   let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
   let arguments = (ins
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 42c520bbb962..980025d54a37 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -402,6 +402,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
                std::is_same<AtenOpT, AtenGtScalarOp>()) {
       compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
           op->getContext(), chlo::ComparisonDirection::GT);
+    } else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::GE);
     } else if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
                std::is_same<AtenOpT, AtenLtScalarOp>()) {
       compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
@@ -1249,6 +1252,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
+  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index c8eeffa2b70b..5b4c00f7e5d4 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1357,6 +1357,48 @@ class DecomposeAtenNativeDropoutBackwardOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenNativeDropoutOp : public OpRewritePattern<AtenNativeDropoutOp> {
+public:
+  using OpRewritePattern<AtenNativeDropoutOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    Value prob = op.p();
+    bool train = false;
+    if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
+      return rewriter.notifyMatchFailure(op, "train must be a boolean constant");
+
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!train) {
+      // TODO(yancey.yx): support inference mode
+      return op.emitError(
+          "native_dropout does not support argument train is false");
+    }
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(
+          op, "only support floating type input for training mode");
+    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
+    Value floatOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
+    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
+    Value maskedInput =
+        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
+    Value output =
+        rewriter.create<AtenMulScalarOp>(loc, inputType, maskedInput, oneMinusP);
+    Value one =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    boolMask = rewriter.create<AtenGeScalarOp>(
+        loc, op.getResult(1).getType(), boolMask, one);
+    rewriter.replaceOp(op, {output, boolMask});
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.var into: aten.var.dim op.
 namespace {
 class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
@@ -3180,6 +3222,8 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    patterns.add<DecomposeAtenNativeDropoutOp>(context);
+    target.addIllegalOp<AtenNativeDropoutOp>();
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
diff --git a/test/Conversion/TorchToMhlo/dropout.mlir b/test/Conversion/TorchToMhlo/dropout.mlir
new file mode 100644
index 000000000000..b61a61b3bf83
--- /dev/null
+++ b/test/Conversion/TorchToMhlo/dropout.mlir
@@ -0,0 +1,47 @@
+// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.native_dropout.train(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: f64) -> (tensor<?x?xf32>, tensor<?x?xi1>) {
+// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[CST_0:.*]] = arith.constant 1 : index
+// CHECK: %[[CST_1:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
+// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
+// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor<?x?xf64>
+// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64>
+// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<?x?xf64>
+// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f64>, tensor<2xindex>) -> tensor<?x?xf64>
+// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xi1>
+// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor<?x?xi1>) -> tensor<?x?xf32>
+// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex>
+// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor<?x?xf32>) {
+// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor<?x?xf32>
+// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
+// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: return %[[T24]], %[[T26]] : tensor<?x?xf32>, tensor<?x?xi1>
+func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) {
+  %bool_true = torch.constant.bool true
+  %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
+  return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
+}
\ No newline at end of file
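
Note on the decomposition: as a rough, hand-written illustration of what the new DecomposeAtenNativeDropoutOp pattern emits at the Torch level for the training case (before the torch-backend-to-mhlo pipeline lowers it further), the rewrite follows the sketch below. The op spellings (in particular torch.valsem.aten.bernoulli.float) and the SSA names are assumptions inferred from the C++ pattern above, not IR captured from the pass:

  %none = torch.constant.none
  %float1 = torch.constant.float 1.000000e+00
  // keep probability = 1 - p
  %keep = torch.aten.sub.float %float1, %p : !torch.float, !torch.float -> !torch.float
  // draw the mask with P(keep) = %keep, still as a float tensor
  %mask = torch.valsem.aten.bernoulli.float %input, %keep, %none : !torch.vtensor<[?,?],f32>, !torch.float, !torch.none -> !torch.vtensor<[?,?],f32>
  %masked = torch.aten.mul.Tensor %mask, %input : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  // scale by the keep probability, as the pattern does
  %output = torch.aten.mul.Scalar %masked, %keep : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
  %int1 = torch.constant.int 1
  // turn the float mask into the i1 mask returned as the second result,
  // via aten.ge.Scalar, which is why AtenGeScalarOp gains an MHLO pattern above
  %bool_mask = torch.aten.ge.Scalar %mask, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>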