diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 855b75c590eb..c8eeffa2b70b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1338,6 +1338,25 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { }; } // namespace +// grad_output * mask * scale +namespace { +class DecomposeAtenNativeDropoutBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Value maskedGradOutput = rewriter.create( + loc, op.getType(), op.grad_output(), op.mask()); + rewriter.replaceOpWithNewOp(op, op.getType(), + maskedGradOutput, op.scale()); + return success(); + } +}; +} // namespace + // Decompose aten.var into: aten.var.dim op. namespace { class DecomposeAtenVarOp : public OpRewritePattern { @@ -3087,6 +3106,8 @@ class DecomposeComplexOpsPass patterns.add(context); target.addIllegalOp(); patterns.add(context); + target.addIllegalOp(); + patterns.add(context); target.addIllegalOp(); patterns.add(context); @@ -3160,6 +3181,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); target.addIllegalOp(); patterns.add(context); patterns.add(context);