Revert "add native_dropout and related ops pattern (#1211)"
This reverts commit c935795.
Yancey1989 authored Aug 16, 2022
1 parent 9d6ee48 commit 04e3730
Showing 6 changed files with 3 additions and 207 deletions.
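For context, the reverted commit (a) registered torch-to-mhlo conversion patterns for aten.size.int, the uniform_ valsem variant, and a family of unary convert ops, and (b) added a decomposition that expands aten.native_dropout in training mode into a Bernoulli mask, an element-wise multiply, and a rescale. The snippet below is only an illustrative sketch of the torch-dialect IR that decomposition targeted; the shapes and dropout probability are made up and are not taken from the deleted dropout.mlir test file.

// Hypothetical input IR (illustrative shapes and probability, not from the reverted test):
func.func @native_dropout(%input: !torch.vtensor<[2,3],f32>)
    -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],i1>) {
  %p = torch.constant.float 5.000000e-01
  %train = torch.constant.bool true
  // The reverted pattern rewrote this op as bernoulli(1 - p), a multiply by the
  // resulting mask, a scalar multiply by (1 - p), and a >= 1 compare to produce
  // the boolean mask result.
  %out, %mask = torch.aten.native_dropout %input, %p, %train
      : !torch.vtensor<[2,3],f32>, !torch.float, !torch.bool
      -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],i1>
  return %out, %mask : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],i1>
}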
3 changes: 1 addition & 2 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5313,10 +5313,9 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
}

def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
NoSideEffect,
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly,
ReadOnly
]> {
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
103 changes: 0 additions & 103 deletions lib/Conversion/TorchToMhlo/Basic.cpp
@@ -71,25 +71,6 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
};
} // namespace

// ConvertAtenUnaryConvertOp legalizes general unary ops into mhlo::ConvertOp
namespace {
template <typename AtenOpT>
class ConvertAtenUnaryConvertOp: public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
};
} // namespace

// aten.ones & aten.zeros
// Ref: Error checking based on the Torch to TOSA lowering
namespace {
@@ -326,9 +307,6 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
std::is_same<AtenOpT, AtenGtScalarOp>()) {
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
op->getContext(), mhlo::ComparisonDirection::GT);
} else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
op->getContext(), mhlo::ComparisonDirection::GE);
} else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
std::is_same<AtenOpT, AtenEqScalarOp>()) {
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
@@ -1002,75 +980,6 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
}
} // namespace

// AtenSizeIntOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
AtenSizeIntOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
auto dim = rewriter.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getIndexType(), adaptor.dim());
auto dimSize = rewriter.create<tensor::DimOp>(
op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim);

rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
op, getTypeConverter()->convertType(op.getType()), dimSize);

return success();
}
} // namespace

// ValsemVariantAtenUniformOp
namespace {
template <>
LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
ValsemVariantAtenUniformOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
auto inputTy = adaptor.self().getType().template cast<RankedTensorType>();
auto loc = op.getLoc();
if (!inputTy) {
op.emitError("input should be ranked tensor type.");
}
auto definingOp = op.self().getDefiningOp();
auto shape = definingOp->getOperand(0);
SmallVector<Value, 4> dimSizes;
getListConstructElements(shape, dimSizes);
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
dSize = rewriter.create<torch::TorchConversion::ToI64Op>(loc, dSize).getResult();
return dSize;
});

auto mhloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), dimSizes);

double fromDoubleValue, toDoubleValue;
if (!matchPattern(op.from(), m_TorchConstantFloat(&fromDoubleValue))) {
op.emitError("operand #1 should be scalar");
}
if (!matchPattern(op.to(), m_TorchConstantFloat(&toDoubleValue))) {
op.emitError("operand #2 should be scalar");
}
Value fromTensor = rewriter.create<mhlo::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(inputTy.getElementType(), fromDoubleValue));
Value toTensor = rewriter.create<mhlo::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));

auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();
rewriter.replaceOpWithNewOp<mhlo::RngOp>(
op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM);
return success();
}
} // namespace
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -1096,15 +1005,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
#undef INSERT_UNARY_FPONLY_PATTERN

#define INSERT_UNARY_CONVERT_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryConvertOp<AtenOp>>(typeConverter, \
context);
INSERT_UNARY_CONVERT_PATTERN(AtenContiguousOp);
INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeOp);
INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp);
#undef INSERT_UNARY_CONVERT_PATTERN

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
@@ -1138,7 +1038,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
@@ -1164,7 +1063,5 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenSizeIntOp);
INSERT_ATENOP_PATTERN(ValsemVariantAtenUniformOp);
#undef INSERT_ATENOP_PATTERN
}
43 changes: 0 additions & 43 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1155,47 +1155,6 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
};
} // namespace

namespace {
class DecomposeAtenNativeDropoutOp : public OpRewritePattern<AtenNativeDropoutOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value input = op.input();
Value prob = op.p();
bool train = false;
if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
return rewriter.notifyMatchFailure(op, "train must be a boolean constant");

BaseTensorType inputType = input.getType().cast<BaseTensorType>();
if (!train) {
// TODO(yancey.yx): support inference mode
return op.emitError(
    "native_dropout does not yet support train=false");
}
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only support floating type input for training mode");
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value floatOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
loc, inputType, input, oneMinusP, /*generator=*/noneVal);
Value maskedInput =
rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
Value output =
rewriter.create<AtenMulScalarOp>(loc, inputType, maskedInput, oneMinusP);
Value one =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
boolMask = rewriter.create<AtenGeScalarOp>(
loc, op.getResult(1).getType(), boolMask, one);
rewriter.replaceOp(op, {output, boolMask});
return success();
}
};
} // namespace
// Decompose aten.var into: aten.var.dim op.
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
@@ -2635,8 +2594,6 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAten_ToCopyOp>(context);
target.addIllegalOp<Aten_ToCopyOp>();
patterns.add<DecomposeAtenDropoutOp>(context);
patterns.add<DecomposeAtenNativeDropoutOp>(context);
target.addIllegalOp<AtenNativeDropoutOp>();
target.addIllegalOp<AtenDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>();
patterns.add<DecomposeAtenNewEmptyOp>(context);
2 changes: 0 additions & 2 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -139,8 +139,6 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());

pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());

if (options.optimize) {
// Clean up any non-canonical code introduced above.
47 changes: 0 additions & 47 deletions test/Conversion/TorchToMhlo/dropout.mlir

This file was deleted.

12 changes: 2 additions & 10 deletions test/Conversion/TorchToMhlo/view_like.mlir
@@ -360,17 +360,9 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]]
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32>
// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]]
// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32>
// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]]
// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
