Skip to content

Commit

Permalink
[Arith] Simplification of ceil, log2, and left_shift (#11646)
Browse files Browse the repository at this point in the history
* [TIR] Simplify expressions using tir.ceil and tir.log2

These expressions are introduced in `topi.math.ceil_log2`, and can
otherwise be propagated through to the generated kernel.

* [Arith] Added left shift handling to ConstIntBoundsAnalyzer

Previously, only right shift was handled.  These left shifts are
used in the `cuda.sort` implementation.

* Update to avoid left shift of negative numbers

* Updated rewriting of log2(x) to only occur in ceil(log2(x))

Per @wrongtest's request, to avoid rounding differences between
different devices.

* Avoid assumptions made of negative arguments to left-shift

* Recognize bounds of int(ceil(log2(arg)))
  • Loading branch information
Lunderberg authored Jun 21, 2022
1 parent bc75487 commit bd800c9
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 2 deletions.
96 changes: 94 additions & 2 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,17 @@ class ConstIntBoundAnalyzer::Impl
}

Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry a;

// int(ceil(log2(cast(n,"float64")))) is used as the
// implementation of topi.math.ceil_log2, and appears in iteration
// bounds.
if (auto opt = FindCeilLog2Arg(op)) {
a = CeilLog2Bounds(opt.value());
} else {
a = VisitExpr(op->value);
}

Entry b = Everything(op->dtype);
return Intersect(a, b);
}
Expand Down Expand Up @@ -314,6 +324,8 @@ class ConstIntBoundAnalyzer::Impl

if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tir::builtin::shift_left())) {
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else {
Expand Down Expand Up @@ -341,6 +353,20 @@ class ConstIntBoundAnalyzer::Impl
}
}

Entry VisitLeftShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);

if (a.min_value < 0 || b.min_value < 0) {
// If either operand can negative, we may run into undefined
// behavior for some targets. In these cases, avoid making any
// assumptions about the result.
return Everything(op->dtype);
}

return BinaryOpBoundary(a, b, InfAwareLeftShift);
}

Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
Expand Down Expand Up @@ -509,7 +535,33 @@ class ConstIntBoundAnalyzer::Impl
return floordiv(x, y);
}
/*!
* \brief Compute x / y, aware of inf.
* \brief Compute x << y, aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
static int64_t InfAwareLeftShift(int64_t x, int64_t y) {
if (x == kPosInf || x == kNegInf) return x;

// Can be replaced with std::bit_width in C++20
auto bit_width = [](int64_t as_signed) {
uint64_t val = std::abs(as_signed);
int num_bits = 0;
while (val) {
++num_bits;
val >>= 1;
}
return num_bits;
};
int x_bits = bit_width(x);
if (x_bits + y < 64) {
return x << y;
} else {
return kPosInf;
}
}
/*!
* \brief Compute x >> y, aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
Expand Down Expand Up @@ -609,6 +661,46 @@ class ConstIntBoundAnalyzer::Impl
}
return {};
}

/*!
* \brief Extract the argument from int(ceil(log2(arg)))
*
* This expression is used as the implementation of
* topi.math.ceil_log2, and can appear in iteration bounds.
*/
static Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) {
if (op->dtype.is_int()) {
if (auto as_call = op->value.as<CallNode>()) {
if (as_call->op.same_as(Op::Get("tir.ceil"))) {
PrimExpr ceil_arg = as_call->args[0];
if (auto arg_call = ceil_arg.as<CallNode>()) {
if (arg_call->op.same_as(Op::Get("tir.log2"))) {
PrimExpr log_arg = arg_call->args[0];
return log_arg;
}
}
}
}
}
return NullOpt;
}

/*! \brief Propagate constraints through ceil(log2(arg))
*
* Helper function for CastNode visitor
*/
Entry CeilLog2Bounds(PrimExpr arg) {
if (auto as_float = arg.as<FloatImmNode>()) {
// A cast from int to float may have already been simplified
// out. Normally we don't inspect floating-point arguments, but here we can
int64_t val = std::ceil(std::log2(as_float->value));
return MakeBound(val, val);
} else {
Entry arg_bounds = VisitExpr(arg);
return MakeBound(std::ceil(std::log2(arg_bounds.min_value)),
std::ceil(std::log2(arg_bounds.max_value)));
}
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {
Expand Down
21 changes: 21 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1640,13 +1640,34 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// the operator overload will eagerly constant fold.
return op->args[0] << op->args[1];
}
} else if (op->op.same_as(Op::Get("tir.ceil"))) {
PrimExpr ceil_arg = op->args[0];
if (auto arg_int = op->args[0].as<IntImmNode>()) {
return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value));
} else if (auto arg_float = ceil_arg.as<FloatImmNode>()) {
return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value)));
} else if (auto arg_call = ceil_arg.as<CallNode>()) {
// ceil(log2(cast(n,"float64"))) is used as the implementation of
// topi.math.ceil_log2, and appears in iteration bounds.
if (arg_call->op.same_as(Op::Get("tir.log2"))) {
PrimExpr log_arg = arg_call->args[0];
if (auto as_float = log_arg.as<FloatImmNode>()) {
// ceil(log2(n)) can be simplified, and should produce the
// same integer result regardless of the target's rounding
// conventions.
return FloatImm(op->dtype, std::ceil(std::log2(as_float->value)));
}
}
}
}

if (op->op.same_as(tir::builtin::likely())) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (auto match = TryMatchLiteralConstraint(op->args[0])) {
return match.value();
}
}

return ret;
}

Expand Down
107 changes: 107 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,5 +391,112 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
A[i, j] = 2


class TestCeilLog2Int(BaseBeforeAfter):
"""Simplify expressions resulting from topi.math.ceil_log2"""

@T.prim_func
def before(A: T.Buffer[1, "int32"]):
A[0] = T.cast(
T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32"
)

@T.prim_func
def expected(A: T.Buffer[1, "int32"]):
A[0] = 4


class TestLeftCeilLog2LowerBound(BaseBeforeAfter):
"""Integer bounds are propagated through topi.math.ceil_log2"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
x = T.cast(
T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"), dtype="float64"), dtype="float64"),
dtype="int32",
)
if x == 11:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftLowerBound(BaseBeforeAfter):
"""Integer bounds are propagated through left shift
min(1 << i) = 1 << min(i)
= 1 << 0
= 1
"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(1, i, dtype="int32") >= 1:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftUpperBound(BaseBeforeAfter):
"""Integer bounds are propagated through left shift
max(31 << i) = 31 << max(i)
= 31 << 15
= 1015808
"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(31, i, dtype="int32") <= 1015808:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftOfNegativeValue(BaseBeforeAfter):
"""No const int bounds of left shift of negative value.
This is target dependent, and does not currently have a specified
behavior in TIR. For example, in CodeGenC, this generates C code
with undefined behavior.
"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if -64 <= T.shift_left(-i, 4, dtype="int32"):
A[i] = 0.0

expected = before


class TestLeftShiftByNegativeValue(BaseBeforeAfter):
"""No const int bounds of left shift by negative bit count.
This is target dependent, and does not currently have a specified
behavior in TIR. For example, in CodeGenC, this generates C code
with undefined behavior.
"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(16, -i, dtype="int32") <= 16:
A[i] = 0.0

expected = before


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit bd800c9

Please sign in to comment.