Skip to content

Commit

Permalink
[mlir][tosa] Rework tosa.apply_scale lowering for 32-bit
Browse files Browse the repository at this point in the history
Added handling rounding behavior in 32-bits for when possible. This
avoids kernel compilation generating scalarized code on platforms where
64-bit vectors are not available.

As the 48-bit lowering requires 64-bit anyway, we added a full 64-bit
solution simplifying the old path.

Reviewed By: dcaballe, mravishankar

Differential Revision: https://reviews.llvm.org/D125583
  • Loading branch information
rsuderman committed May 17, 2022
1 parent d4545e6 commit 9294a1e
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 170 deletions.
5 changes: 4 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,10 @@ def TosaToArith : Pass<"tosa-to-arith"> {
let options = [
Option<"includeApplyRescale", "include-apply-rescale",
"bool", /*default=*/"false",
"Whether to include the lowering for tosa.apply_rescale to arith">
"Whether to include the lowering for tosa.apply_rescale to arith">,
Option<"use32Bit", "use-32-bit",
"bool", /*default=*/"false",
"Whether to prioritze lowering to 32-bit operations">
];

let constructor = "tosa::createTosaToArith()";
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ std::unique_ptr<Pass> createTosaToArith();

void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);

void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns);
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns,
bool include32Bit = false);

} // namespace tosa
} // namespace mlir
Expand Down
243 changes: 169 additions & 74 deletions mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
Expand Down Expand Up @@ -49,103 +50,194 @@ Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
return rewriter.getIntegerAttr(type, value);
}

Value getConstantValue(Location loc, Type type, int64_t value,
PatternRewriter &rewriter) {
return rewriter.create<arith::ConstantOp>(
loc, getConstantAttr(type, value, rewriter));
}

// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
// using 64-bit operations to perform the necessary multiply, bias, and shift.
// Multiple types are used to use minimal bit width operations.
class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
class ApplyScaleGenericOpConverter
: public OpRewritePattern<tosa::ApplyScaleOp> {
public:
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value value32 = op.value();
Value value = op.value();
Value multiplier32 = op.multiplier();
Value shift8 = op.shift();

bool doubleRound = op.double_round();
Type inType = op.value().getType();
Type resultTy = op.getType();

Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy);
Type valueTy = value.getType();
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);

Value one8 = rewriter.create<arith::ConstantOp>(
loc, getConstantAttr(i8Ty, 1, rewriter));
Value one64 = rewriter.create<arith::ConstantOp>(
loc, getConstantAttr(i64Ty, 1, rewriter));

Value shiftSubOne8 = rewriter.create<arith::SubIOp>(loc, shift8, one8);

// The rounding value semantics below equate to the following code:
// int64_t round = 1 << (shift - 1);
// if (double_round) {
// if (shift > 31 && value >= 0) round += 1<<30;
// if (shift > 31 && value < 0) round -= 1<<30;
// }
//
// Note that minimal bitwidth operators are used throughout the block.

Value round64 = rewriter.create<arith::ShLIOp>(
loc, one64, rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));

// Double rounding is performing a round operation before the shift
if (doubleRound) {
Value one32 = rewriter.create<arith::ConstantOp>(
loc, getConstantAttr(i32Ty, 1, rewriter));
Value shift32 = rewriter.create<arith::ExtSIOp>(loc, i32Ty, shift8);
Value thirty32 = rewriter.create<arith::ConstantOp>(
loc, getConstantAttr(i32Ty, 30, rewriter));

Value shiftThirty32 =
rewriter.create<arith::ShLIOp>(loc, one32, thirty32);
Value shiftThirty64 =
rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftThirty32);

// Round value needs to with be added or subtracted depending on the sign
// of the input value.
Value roundAdd64 =
rewriter.create<arith::AddIOp>(loc, round64, shiftThirty64);
Value roundSub64 =
rewriter.create<arith::SubIOp>(loc, round64, shiftThirty64);

Value zero32 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(inType));
Value valueGreaterThanZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value32, zero32);
Value zero = getConstantValue(loc, valueTy, 0, rewriter);
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);

Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.shift());

// Compute the multiplication in 64-bits then select the high / low parts.
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
Value multiplier64 =
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
Value multiply64 =
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);

Value doubleRound64 = rewriter.create<arith::SelectOp>(
loc, valueGreaterThanZero, roundAdd64, roundSub64);
// Apply normal rounding.
Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);

// We only perform double rounding if the shift value is greater than 32.
Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
loc, getConstantAttr(i32Ty, 32, rewriter));
Value shiftGreaterThanThirtyTwo = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
round64 = rewriter.create<arith::SelectOp>(loc, shiftGreaterThanThirtyTwo,
doubleRound64, round64);
// Apply double rounding if necessary.
if (op.double_round()) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
Value positive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value, zero);
Value dir =
rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
Value valid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
multiply64 =
rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
}

// The computation below equates to the following pseudocode:
// int64_t result = (int64_t)value * multiplier + round;
// result = result >> shift;
//
// Note that multiply and shift need to be perform in i64 to preserve bits.
Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);

rewriter.replaceOp(op, result32);
return success();
}
};

class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
public:
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();

Type resultTy = op.getType();
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);

Value value = op.value();
if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
return failure();
}

Value value32 = op.value();
Value multiplier32 = op.multiplier();
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.shift());

// Constants used during the scaling operation.
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter);

// Compute the multiplication in 64-bits then select the high / low parts.
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
Value multiplier64 =
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
Value shift64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, shift8);
Value multiply64 =
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);

// Multiply as a pair of i64 values to guarantee the end value fits.
Value result64 = rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
result64 = rewriter.create<arith::AddIOp>(loc, result64, round64);
result64 = rewriter.create<arith::ShRSIOp>(loc, result64, shift64);
// Grab out the high/low of the computation
Value high64 =
rewriter.create<arith::ShRUIOp>(loc, multiply64, thirtyTwo64);
Value high32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, high64);
Value low32 = rewriter.create<arith::MulIOp>(loc, value32, multiplier32);

Value result32 = rewriter.create<arith::TruncIOp>(loc, resultTy, result64);
// Determine the direction and amount to shift the high bits.
Value shiftOver32 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
Value roundHighBits = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);

rewriter.replaceOp(op, result32);
Value shiftHighL =
rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
Value shiftHighR =
rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);

shiftHighL =
rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
shiftHighR =
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);

// Conditionally perform our double round.
if (op.double_round()) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value32, zero32);

Value roundDir =
rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
roundDir =
rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);

Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);

Value shiftRound =
rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);

low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
}

// Conditionally apply rounding in the low bits.
{
Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
roundBit);

Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
Value wasRounded = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ugt, low32, newLow32);
low32 = newLow32;

Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
}

// Conditionally apply rounding in the high bits.
{
Value shiftSubOne =
rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
zero32);
high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
}

// Combine the correct high/low bits into the final rescale result.
high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);

// Apply the rounding behavior and shift to the final alignment.
Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);

// Truncate if necessary.
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
}

rewriter.replaceOp(op, result);
return success();
}
};
Expand All @@ -158,6 +250,9 @@ void mlir::tosa::populateTosaToArithConversionPatterns(
}

void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<ApplyScaleOpConverter>(patterns->getContext());
RewritePatternSet *patterns, bool include32Bit) {
patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
if (include32Bit) {
patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
}
}
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct TosaToArith : public TosaToArithBase<TosaToArith> {
mlir::tosa::populateTosaToArithConversionPatterns(&patterns);

if (this->includeApplyRescale) {
mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns);
mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns,
this->use32Bit);
target.addIllegalOp<tosa::ApplyScaleOp>();
}

Expand Down
Loading

0 comments on commit 9294a1e

Please sign in to comment.