Skip to content

Commit

Permalink
[mlir][math] Added algebraic simplification for IPowI operation.
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D130390
  • Loading branch information
vzakhari committed Aug 15, 2022
1 parent 133624a commit 2dde4ba
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 1 deletion.
93 changes: 92 additions & 1 deletion mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,100 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
return failure();
}

//----------------------------------------------------------------------------//
// IPowIOp strength reduction.
//----------------------------------------------------------------------------//

namespace {
/// Rewrite pattern for `math.ipowi` operations whose exponent is a known
/// constant (scalar, or splat vector). The rewrite itself is implemented in
/// `matchAndRewrite`, which gives up when the absolute value of the exponent
/// is larger than `exponentThreshold`.
struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
  /// Builds the pattern.
  ///
  /// \param context           MLIR context forwarded to the base pattern.
  /// \param exponentThreshold Largest `abs(exponent)` the pattern will expand
  ///                          (defaults to 3).
  /// \param benefit           Pattern benefit forwarded to the base pattern.
  /// \param generatedNames    Generated operation names forwarded to the base
  ///                          pattern.
  IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
                         PatternBenefit benefit = 1,
                         ArrayRef<StringRef> generatedNames = {})
      : OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
        exponentThreshold(exponentThreshold) {}

  /// Attempts the rewrite of `op`; defined out of line.
  LogicalResult matchAndRewrite(math::IPowIOp op,
                                PatternRewriter &rewriter) const final;

  /// Upper bound on `abs(exponent)` for which the rewrite is applied.
  unsigned exponentThreshold;
};
} // namespace

/// Strength-reduces `math.ipowi` when the exponent is a constant scalar or a
/// constant splat vector:
///   * `ipowi(x, 0)` is replaced with the constant `1`;
///   * `ipowi(x, n)` with `0 < n <= exponentThreshold` is replaced with
///     `x * x * ...` (n factors);
///   * `ipowi(x, -n)` with `0 < n <= exponentThreshold` is replaced with
///     `(1 / x) * (1 / x) * ...` (n factors), using signed division.
///
/// Returns `failure()` (leaving `op` untouched) when the exponent is not a
/// constant scalar/splat or when `abs(exponent)` exceeds `exponentThreshold`.
LogicalResult
IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
                                        PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value base = op.getLhs();

  IntegerAttr scalarExponent;
  DenseIntElementsAttr vectorExponent;

  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));

  // Simplify cases with known exponent value.
  int64_t exponentValue = 0;
  if (isScalar)
    exponentValue = scalarExponent.getInt();
  else if (isVector && vectorExponent.isSplat())
    exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
  else
    return failure();

  // Maybe broadcasts scalar value into vector type compatible with `op`.
  auto bcast = [&](Value value) -> Value {
    if (auto vec = op.getType().dyn_cast<VectorType>())
      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
    return value;
  };

  // Materializes the constant `1` of the result's element type, broadcast to
  // the result type when `op` produces a vector.
  auto createOne = [&]() -> Value {
    Value one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
    return bcast(one);
  };

  if (exponentValue == 0) {
    // Replace `ipowi(x, 0)` with `1`.
    rewriter.replaceOp(op, createOne());
    return success();
  }

  // Compute `abs(exponent)` in unsigned arithmetic: negating INT64_MIN as a
  // signed value would overflow (undefined behavior), whereas the unsigned
  // negation is well defined and yields 2^63.
  const bool exponentIsNegative = exponentValue < 0;
  const uint64_t exponentMagnitude =
      exponentIsNegative ? uint64_t{0} - static_cast<uint64_t>(exponentValue)
                         : static_cast<uint64_t>(exponentValue);

  // Bail out if `abs(exponent)` exceeds the threshold.
  if (exponentMagnitude > exponentThreshold)
    return failure();

  // Invert the base for negative exponent, i.e. for
  // `ipowi(x, negative_exponent)` set `x` to `1 / x`.
  if (exponentIsNegative)
    base = rewriter.create<arith::DivSIOp>(loc, createOne(), base);

  Value result = base;
  // Transform to naive sequence of multiplications:
  //   * For positive exponent case replace:
  //       `ipowi(x, positive_exponent)`
  //     with:
  //       x * x * x * ...
  //   * For negative exponent case replace:
  //       `ipowi(x, negative_exponent)`
  //     with:
  //       (1 / x) * (1 / x) * (1 / x) * ...
  for (uint64_t i = 1; i < exponentMagnitude; ++i)
    result = rewriter.create<arith::MulIOp>(loc, result, base);

  rewriter.replaceOp(op, result);
  return success();
}

//----------------------------------------------------------------------------//

/// Adds the math algebraic simplification patterns defined in this file
/// (`PowFStrengthReduction` and `IPowIStrengthReduction`) to `patterns`,
/// constructing each with the context of the pattern set.
void mlir::populateMathAlgebraicSimplificationPatterns(
    RewritePatternSet &patterns) {
  // Each pattern is registered exactly once; a separate
  // `add<PowFStrengthReduction>` call here would insert a duplicate.
  patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(
      patterns.getContext());
}
90 changes: 90 additions & 0 deletions mlir/test/Dialect/Math/algebraic-simplification.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,93 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}

// Zero exponent: for both the scalar and the splat-vector form, the expected
// output returns the constant 1 (`1 : i32` and `dense<1> : vector<4xi32>`)
// in place of the `math.ipowi` results.
// CHECK-LABEL: @ipowi_zero_exp(
// CHECK-SAME: %[[ARG0:.+]]: i32
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
// CHECK-SAME: -> (i32, vector<4xi32>) {
func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>) {
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
// CHECK: return %[[CST_S]], %[[CST_V]]
%c = arith.constant 0 : i32
%v = arith.constant dense <0> : vector<4xi32>
%0 = math.ipowi %arg0, %c : i32
%1 = math.ipowi %arg1, %v : vector<4xi32>
return %0, %1 : i32, vector<4xi32>
}

// Exponents 1 and -1 (scalar and splat vector): for exponent 1 the expected
// output returns the base argument itself; for exponent -1 it returns
// `arith.divsi` of the constant 1 by the base.
// CHECK-LABEL: @ipowi_exp_one(
// CHECK-SAME: %[[ARG0:.+]]: i32
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
// CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
// CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
// CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
%c1 = arith.constant 1 : i32
%v1 = arith.constant dense <1> : vector<4xi32>
%0 = math.ipowi %arg0, %c1 : i32
%1 = math.ipowi %arg1, %v1 : vector<4xi32>
%cm1 = arith.constant -1 : i32
%vm1 = arith.constant dense <-1> : vector<4xi32>
%2 = math.ipowi %arg0, %cm1 : i32
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
}

// Exponents 2 and -2 (scalar and splat vector): for exponent 2 the expected
// output is a single `arith.muli` of the base with itself; for exponent -2 it
// is `1 / base` (`arith.divsi`) multiplied by itself.
// Note: the SSA names %c1/%v1/%cm1/%vm1 below hold the values 2 and -2.
// CHECK-LABEL: @ipowi_exp_two(
// CHECK-SAME: %[[ARG0:.+]]: i32
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
// CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
// CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
%c1 = arith.constant 2 : i32
%v1 = arith.constant dense <2> : vector<4xi32>
%0 = math.ipowi %arg0, %c1 : i32
%1 = math.ipowi %arg1, %v1 : vector<4xi32>
%cm1 = arith.constant -2 : i32
%vm1 = arith.constant dense <-2> : vector<4xi32>
%2 = math.ipowi %arg0, %cm1 : i32
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
}

// Exponents 3 and -3 (scalar and splat vector): for exponent 3 the expected
// output is a chain of two `arith.muli` ops on the base; for exponent -3 it
// is the same chain applied to `1 / base` (`arith.divsi`).
// Note: the SSA names %c1/%v1/%cm1/%vm1 below hold the values 3 and -3.
// CHECK-LABEL: @ipowi_exp_three(
// CHECK-SAME: %[[ARG0:.+]]: i32
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
// CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
// CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
// CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
// CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
// CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
// CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
// CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
// CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
%c1 = arith.constant 3 : i32
%v1 = arith.constant dense <3> : vector<4xi32>
%0 = math.ipowi %arg0, %c1 : i32
%1 = math.ipowi %arg1, %v1 : vector<4xi32>
%cm1 = arith.constant -3 : i32
%vm1 = arith.constant dense <-3> : vector<4xi32>
%2 = math.ipowi %arg0, %cm1 : i32
%3 = math.ipowi %arg1, %vm1 : vector<4xi32>
return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
}

0 comments on commit 2dde4ba

Please sign in to comment.