[MHLO] support non-constant torch scalar in BasicOps #1134

Merged 1 commit on Aug 3, 2022.
42 changes: 13 additions & 29 deletions lib/Conversion/TorchToMhlo/BasicOp.cpp
@@ -159,23 +159,15 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
}

if (!rhsType) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
outElemTy, {})))
return op.emitError("currently only scalar constants are supported for "
"conversion in MHLO operation");
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
}

lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);

if (!skipMultiplyAlpha(op.alpha())) {
Value alpha;
if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
op.alpha(), alpha, outElemTy, {},
/*checkForUnity=*/false))) {
return op.emitError("currently only scalar constants are supported for "
"alpha in conversion to MHLO operation");
}
Value alpha =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy);
DenseIntElementsAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
@@ -216,13 +208,13 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
if (!rhsType) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
outElemTy, {})))
return op.emitError("currently only scalar constants are supported for "
"conversion in MHLO operation");
}

Value lhsTensor = lhs;
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs;
} else if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
@@ -263,11 +255,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
}

if (!rhsTy) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
lhsElemTy, {}))) {
return op.emitError("currently only scalar constants are supported for "
"conversion in MHLO operation");
}
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy);
}

// TODO: what is the PyTorch default type promotion?
@@ -569,12 +557,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
.cast<RankedTensorType>();
auto outputShape = outputType.getShape();
auto outputElemType = outputType.getElementType();
Value mhloTensor;
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
outputElemType, outputShape,
false))) {
return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
}
Value mhloTensor =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
rewriter.replaceOp(op, mhloTensor);
return success();
}
@@ -1020,4 +1004,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
#undef INSERT_ATENOP_PATTERN
}
}
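
Each hunk above follows the same pattern: the old code required op.other() or op.alpha() to match a Torch constant (via torchScalarToMhloTensor / torchAlphaToMhloTensor) and emitted an error otherwise, while the new code passes the already-converted adaptor operand to mhlo::scalarToMhloTensor, so non-constant scalars legalize as well. As a hypothetical illustration (not taken from this PR's tests; the function name and shapes are made up), input IR like the following, where the scalar is a block argument rather than a torch.constant.int, can now reach these patterns instead of failing to legalize:

// Hypothetical input IR: %s is a runtime value, not a torch.constant.int.
func.func @add_scalar(%t: !torch.vtensor<[2],si64>, %s: !torch.int) -> !torch.vtensor<[2],si64> {
  %int1 = torch.constant.int 1
  %0 = torch.aten.add.Scalar %t, %s, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2],si64>
  return %0 : !torch.vtensor<[2],si64>
}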
98 changes: 10 additions & 88 deletions lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
@@ -174,93 +174,15 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
return const_op.getResult();
}

// TODO: Support for variable scalar.
LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &mhloTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape,
bool doBroadcast) {
// Retrieve a const float or int value but create the out Tensor with dtype.
double doubleValue;
auto isFloat =
matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));

int64_t intValue;
auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));

if (!isFloat && !isInt)
return op->emitError("Unable to extract the scalar constant");

if (dtype.isa<mlir::FloatType>()) {
if (doBroadcast) {
mhloTensor = getSplatConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
} else {
mhloTensor = mhlo::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
.getValue();
}
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
auto w = intType.getWidth();
if (w != 32 && w != 64)
return op->emitError("Unsupported integer type") << intType;

if (w == 32) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return op->emitError("Supplied value of scalar constant exceeds limits "
"of destination type");
}
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue);
if (doBroadcast) {
mhloTensor =
getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
}
} else if (w == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return op->emitError("Supplied value of scalar constant exceeds limits "
"of destination type");
}
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
if (doBroadcast) {
mhloTensor =
getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
}
}
} else
return op->emitError("Usupported element type");

return success();
}

LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value alphaScalar,
Value &alphaTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape,
bool checkForUnity) {
if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor,
dtype, dshape)))
return success();

// `alpha` has not been specified.
int64_t alphaValue;
if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
return op->emitError("Currently only scalar constants are supported for "
"alpha in MHLO operation");
// When no alpha has been specified, this must be 1.
if (checkForUnity && alphaValue != 1)
return op->emitError("Unsupported integer value for alpha");

alphaTensor =
mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue);

return success();
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype) {
auto tensor = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ArrayRef<Value>{scalarValue});
auto dtype_tensor =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<mhlo::ReshapeOp>(
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
dtype_tensor);
}

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
@@ -439,4 +361,4 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
.getResult();
}
} // namespace mhlo
} // namespace mlir
} // namespace mlir
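
The replacement helper no longer pattern-matches a compile-time constant: it takes the converted scalar SSA value and materializes it as a rank-0 tensor with tensor.from_elements, mhlo.convert, and mhlo.reshape. A minimal sketch of the IR it emits for an i64 scalar, mirroring the updated CHECK lines in test/Conversion/TorchToMhlo/basic.mlir below (%scalar stands in for the converted value):

// Wrap the scalar in a 1-element tensor, convert to the target dtype,
// then reshape to rank 0.
%0 = tensor.from_elements %scalar : tensor<1xi64>
%1 = mhlo.convert %0 : tensor<1xi64>
%2 = "mhlo.reshape"(%1) : (tensor<1xi64>) -> tensor<i64>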
13 changes: 2 additions & 11 deletions lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
@@ -47,17 +47,8 @@ template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape);

LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &mhloTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape,
bool doBroadcast = true);

LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value alphaScalar,
Value &alphaTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape,
bool checkForUnity);
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);

16 changes: 10 additions & 6 deletions test/Conversion/TorchToMhlo/basic.mlir
@@ -41,11 +41,15 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {

// -----

// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
// CHECK-SAME: ) -> !torch.vtensor<[],si64> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
@@ -251,4 +255,4 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
%2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
return %result0 : !torch.vtensor<[3,7,4,5],f32>
}
}
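
A hypothetical companion case (not part of this PR's test changes) showing the non-constant path the PR title refers to: the scalar arrives as a function argument instead of a torch.constant.int, so the lowering has to build the tensor from a runtime value.

// -----

// Hypothetical: scalar comes from a block argument rather than a constant.
func.func @torch.prim.NumToTensor.Scalar$nonconst(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}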