This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

BladeDISC-related patches
* Rewrite mhlo with stablehlo after rebase.
* Fix BAZEL building error of multiple definition.
* Fix float width
* Fix divide_floor & export promoteTypes api (#9)
* Comply with older PyTorch versions
* Add native_dropout_backward & native_layer_norm_backward decomposition (#15)
* Add native_dropout and related ops pattern (llvm#1211)
* [MHLO] fix dot general contract
* Fix batch_norm, div.Tensor_mode and folder (#21)
* Reimplement linear lowering
* Reimplement 2-D rhs for matmul
* Add torchdynamo
* Decompose torch.slice_scatter (llvm#1622)
* Fix i64 torch.tensor dtype
* Add more mhlo basic converters
* Alleviate softmax datatype check (#24)
* Fix decompose native_batch_norm (#27)
* Support group_norm lowering (#25)
* Decompose torch.ones/zeros (#28)
* Fix softmax output type
* Fix gather
* Fix some decompose patterns
* Do not check asserts at runtime (#31)
* Fix bool tensor attr conversion bug (#32)
* Fix mlirDenseElementsAttrBoolGet

---------

Co-authored-by: ZHENG, Zhen <[email protected]>
Tanyo Kwok and JamesTheZ committed Jul 25, 2023
1 parent 0caaf8d commit f63f51a
Showing 26 changed files with 1,117 additions and 138 deletions.
@@ -45,7 +45,8 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value scalarValue, Type dtype);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
TensorType outType);

Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType);
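
The extra Location parameter lets callers decide where the promotion cast is attributed instead of the helper picking a location itself. A rough sketch of the relocated helper, assuming the promotion is a plain element-type cast via stablehlo::ConvertOp as in upstream torch-mlir (not necessarily the exact BladeDISC code):

Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                  TensorType outType) {
  // No-op when the element types already match.
  TensorType inType = input.getType().cast<TensorType>();
  if (inType.getElementType() == outType.getElementType())
    return input;
  // Keep the input shape; only the element type is promoted here.
  TensorType promotedType =
      inType.cloneWith(inType.getShape(), outType.getElementType());
  return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
}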
96 changes: 78 additions & 18 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4929,6 +4929,65 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$use_input_stats,
Torch_FloatType:$momentum,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}

def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
Torch_IntType:$num_groups,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
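
Like the other generated Torch ops, these definitions come with default builders, so a rewrite pattern can materialize them directly. An illustrative snippet only — loc, resultType, and the operand Values are assumed to already exist in the surrounding pattern:

// Illustration: build the newly registered aten.group_norm op; the operand
// order mirrors the ODS argument list above.
Value groupNorm = rewriter.create<AtenGroupNormOp>(
    loc, resultType, input, numGroups, weight, bias, eps, cudnnEnabled);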

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -7233,9 +7292,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
}

def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
Pure,
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
ReadOnly,
]> {
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
@@ -7742,53 +7802,53 @@ def Torch_AtenMaxOp : Torch_Op<"aten.max", [
}];
}

def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
let summary = "Generated op for `aten::amax : (Tensor, int[]?, bool) -> Tensor`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$values,
AnyTorchTensorType:$indices
AnyTorchTensorType:$results
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaxDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
void AtenAmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::amax : (Tensor, int[], bool) -> (Tensor)`";
let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
Torch_IntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
AnyTorchTensorType:$values,
AnyTorchTensorType:$indices
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void AtenAmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
void AtenMaxDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -170,6 +170,7 @@ enum ReductionType { MAX, MEAN, MIN, SUM, PROD };

ReductionType get_reduction_enum(const llvm::StringRef &reduce);

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
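
Exporting promoteTypes makes PyTorch's scalar-type promotion table (mirroring c10::promoteTypes) callable from the conversion passes. A hedged usage sketch, with values chosen only to show the expected behaviour:

// Sketch: a float operand combined with an int64 operand computes in float,
// matching PyTorch's type-promotion rules.
torch_upstream::ScalarType lhs = torch_upstream::ScalarType::Float;
torch_upstream::ScalarType rhs = torch_upstream::ScalarType::Long;
torch_upstream::ScalarType common = torch_upstream::promoteTypes(lhs, rhs);
// common == torch_upstream::ScalarType::Float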
@@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

@@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

@@ -80,6 +82,7 @@ def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}

def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
@@ -98,6 +101,7 @@ def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}
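
The folders on to_i1/from_i1 are presumably there to cancel round-trip casts left behind when torch and builtin types are mixed. A minimal sketch of such a fold, assuming the FoldAdaptor-based signature used by this LLVM version (not necessarily the patch's exact body):

// Sketch: to_i1(from_i1(x)) folds back to x; the from_i1 direction is the
// mirror image.
OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
  if (auto cast = getOperand().getDefiningOp<FromI1Op>())
    return cast.getOperand();
  return nullptr;
}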

def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
9 changes: 8 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -429,13 +429,20 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);

target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
// target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
// patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
// context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);

target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
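
ConvertAtenBinaryOp is the pre-existing template in this file that maps a two-operand Torch scalar op onto the matching arith op; the patch simply instantiates it for aten.remainder.int with arith::RemSIOp. A simplified sketch of what such a template does (not the file's verbatim code):

// Simplified: the adaptor already carries the type-converted operands, so
// the rewrite is a one-to-one replacement (e.g. AtenRemainderIntOp becomes
// arith::RemSIOp on i64 values).
template <typename AtenOp, typename BinOp>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op, typename AtenOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<BinOp>(op, adaptor.getA(), adaptor.getB());
    return success();
  }
};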
48 changes: 24 additions & 24 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -32,6 +32,7 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
using namespace mlir::torch::TorchConversion;

LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
@@ -148,7 +149,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
self = hlo::promoteType(rewriter, self, outType);
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
return success();
}
@@ -253,8 +254,8 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
->convertType(op.getType())
.template cast<TensorType>();

lhs = hlo::promoteType(rewriter, lhs, outTy);
rhs = hlo::promoteType(rewriter, rhs, outTy);
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);

rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr);
@@ -300,8 +301,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
}
}

lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);

if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
@@ -354,8 +355,8 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@@ -427,7 +428,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
}

// TODO: what is the PyTorch default type promotion?
rhs = hlo::promoteType(rewriter, rhs, lhsTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);

chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr;
@@ -494,8 +495,10 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
Value lhs =
hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType);
Value rhs =
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);

DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@@ -610,8 +613,8 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
// promote self and other types
self = hlo::promoteType(rewriter, self, outType);
other = hlo::promoteType(rewriter, other, outType);
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);

if (failed(
broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
@@ -807,8 +810,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
bcastDimensions);
@@ -1212,7 +1215,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(

// Promote type
for (auto &v : builtinTensors) {
v = hlo::promoteType(rewriter, v, outType);
v = hlo::promoteType(rewriter, op.getLoc(), v, outType);
}

rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
@@ -1404,8 +1407,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
auto outTy =
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();

lhs = hlo::promoteType(rewriter, lhs, outTy);
rhs = hlo::promoteType(rewriter, rhs, outTy);
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);

rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr);
@@ -1547,12 +1550,10 @@ class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool condition;
if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: condition must be a constant");
}
if (!condition) {
return op->emitError("condition must be true");
if (matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) {
if (!condition) {
return op->emitError("condition must be true");
}
}
rewriter.eraseOp(op);
return success();
@@ -1679,7 +1680,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);