Skip to content

Commit

Permalink
Add native_dropout_backward & native_layer_norm_backward decomposition (
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok committed Oct 31, 2022
1 parent 8429920 commit aeffd16
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,25 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
};
} // namespace

// grad_output * mask * scale
namespace {
class DecomposeAtenNativeDropoutBackwardOp
: public OpRewritePattern<AtenNativeDropoutBackwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Value maskedGradOutput = rewriter.create<AtenMulTensorOp>(
loc, op.getType(), op.grad_output(), op.mask());
rewriter.replaceOpWithNewOp<AtenMulScalarOp>(op, op.getType(),
maskedGradOutput, op.scale());
return success();
}
};
} // namespace

// Decompose aten.var into: aten.var.dim op.
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
Expand Down Expand Up @@ -3087,6 +3106,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormBackwardOp>();
patterns.add<DecomposeAtenNativeLayerNormBackwardOp>(context);

target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
Expand Down Expand Up @@ -3160,6 +3181,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<Aten_ToCopyOp>();
patterns.add<DecomposeAtenDropoutOp>(context);
target.addIllegalOp<AtenDropoutOp>();
patterns.add<DecomposeAtenNativeDropoutBackwardOp>(context);
target.addIllegalOp<AtenNativeDropoutBackwardOp>();
target.addIllegalOp<AtenNewEmptyOp>();
patterns.add<DecomposeAtenNewEmptyOp>(context);
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
Expand Down

0 comments on commit aeffd16

Please sign in to comment.