Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Simplification of ceil, log2, and left_shift #11646

Merged
merged 7 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,17 @@ class ConstIntBoundAnalyzer::Impl
}

// Compute integer bounds for a cast expression.
//
// The bound of the operand (or, for the ceil_log2 pattern, a dedicated
// bound computed by CeilLog2Bounds) is intersected with
// Everything(op->dtype), presumably the widest bound for the
// destination dtype.
Entry VisitExpr_(const CastNode* op) final {
// NOTE(review): this line and the declaration below both define `a`;
// presumably the diff view shows the removed and the added line
// together -- confirm against the merged file.
Entry a = VisitExpr(op->value);
Entry a;

// int(ceil(log2(cast(n,"float64")))) is used as the
// implementation of topi.math.ceil_log2, and appears in iteration
// bounds.
if (auto opt = FindCeilLog2Arg(op)) {
// Pattern matched: bound the result from the bounds of the log2 argument.
a = CeilLog2Bounds(opt.value());
} else {
// Generic cast: start from the bounds of the value being cast.
a = VisitExpr(op->value);
}

Entry b = Everything(op->dtype);
return Intersect(a, b);
}
Expand Down Expand Up @@ -314,6 +324,8 @@ class ConstIntBoundAnalyzer::Impl

if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tir::builtin::shift_left())) {
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else {
Expand Down Expand Up @@ -341,6 +353,20 @@ class ConstIntBoundAnalyzer::Impl
}
}

// Compute integer bounds for a left shift call, `args[0] << args[1]`.
Entry VisitLeftShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);

if (a.min_value < 0 || b.min_value < 0) {
// If either operand can be negative, we may run into undefined
// behavior for some targets. In these cases, avoid making any
// assumptions about the result.
return Everything(op->dtype);
}

// Both operands are known to be non-negative here; combine their
// bounds using the overflow-aware shift helper.
return BinaryOpBoundary(a, b, InfAwareLeftShift);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a or b has a negative min_value, does taking the max narrow the bound? Should we return Everything if the bound can't be proved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this might depend on the target. For the CodegenC backend, it doesn't narrow the bounds for any legal use of <<, because negative arguments are entirely undefined behavior anyways. Taking the max here is a way to express that same constraint. That is, even if we can't prove that the argument is non-negative, its use in a bitshifting operator provides a constraint.

That said, since my primary goal is to improve simplifications from ceil_log2, perhaps it would be better to look for a ceil(log2(x)) call, and use that directly. The aggressive optimizations that C++ compilers make based on undefined behavior reasoning are controversial for a reason, and I'd like to avoid introducing similar logic in TVM unless required.

Copy link
Member

@tqchen tqchen Jun 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that simplification logic like this is used everywhere, it merits extra care. Specialization might make sense, to be on the safe side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I had assumed that this was safe, because a similar compile-time check is used for the constant folding in operator<<, but in that case the left shift would be entirely removed. By removing a conditional, but still allowing the left shift to remain, we're asserting that we are handling the left shift identically to how it will be at runtime.

I've updated the handling of left shift to return Everything in the case of potentially negative arguments, and added specific handling for ceil(log2(x)).

}

Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
Expand Down Expand Up @@ -509,7 +535,33 @@ class ConstIntBoundAnalyzer::Impl
return floordiv(x, y);
}
/*!
* \brief Compute x / y, aware of inf.
* \brief Compute x << y, aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
static int64_t InfAwareLeftShift(int64_t x, int64_t y) {
  // An infinite left operand stays infinite, whatever the shift amount.
  if (x == kPosInf || x == kNegInf) return x;

  // Number of bits needed to represent |x|.
  // Can be replaced with std::bit_width in C++20.
  auto bit_width = [](int64_t as_signed) {
    // Take the magnitude in unsigned arithmetic: std::abs overflows
    // (undefined behavior) for the most negative int64_t value.
    uint64_t val = static_cast<uint64_t>(as_signed);
    if (as_signed < 0) val = 0 - val;
    int num_bits = 0;
    while (val) {
      ++num_bits;
      val >>= 1;
    }
    return num_bits;
  };
  int x_bits = bit_width(x);

  // Equivalent to `x_bits + y < 64`, but written so that no addition is
  // performed on `y`.  The shift amount may itself be a very large
  // (e.g. "infinite") bound, in which case `x_bits + y` would overflow
  // a signed 64-bit integer.
  if (y < 64 - x_bits) {
    return x << y;
  } else {
    // The shifted value would not fit in 63 value bits; report +inf.
    return kPosInf;
  }
}
/*!
* \brief Compute x >> y, aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
Expand Down Expand Up @@ -609,6 +661,46 @@ class ConstIntBoundAnalyzer::Impl
}
return {};
}

/*!
* \brief Extract the argument from int(ceil(log2(arg)))
*
* This expression is used as the implementation of
* topi.math.ceil_log2, and can appear in iteration bounds.
*/
static Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) {
  // Only casts to an integer dtype can match int(ceil(log2(arg))).
  if (!op->dtype.is_int()) {
    return NullOpt;
  }

  // The value being cast must be a call to tir.ceil ...
  auto ceil_call = op->value.as<CallNode>();
  if (!ceil_call) {
    return NullOpt;
  }
  if (!ceil_call->op.same_as(Op::Get("tir.ceil"))) {
    return NullOpt;
  }

  // ... whose single operand is itself a call to tir.log2.
  PrimExpr ceil_operand = ceil_call->args[0];
  auto log2_call = ceil_operand.as<CallNode>();
  if (!log2_call) {
    return NullOpt;
  }
  if (!log2_call->op.same_as(Op::Get("tir.log2"))) {
    return NullOpt;
  }

  // Matched: hand back the argument of log2.
  PrimExpr log2_operand = log2_call->args[0];
  return log2_operand;
}

/*! \brief Propagate constraints through ceil(log2(arg))
*
* Helper function for CastNode visitor
*/
Entry CeilLog2Bounds(PrimExpr arg) {
  // \param arg The argument of log2 in int(ceil(log2(arg))).
  // \return Bounds on ceil(log2(arg)), or an unconstrained bound when
  //         log2 may be evaluated outside of its domain.  The caller
  //         is expected to intersect the result with the bounds of the
  //         destination dtype.

  if (auto as_float = arg.as<FloatImmNode>()) {
    // A cast from int to float may have already been simplified
    // out.  Normally we don't inspect floating-point arguments, but here we can
    double value = as_float->value;
    if (!(value > 0) || !std::isfinite(value)) {
      // log2 of zero, a negative number, NaN or infinity is not a
      // finite number, and converting such a result to int64_t is
      // undefined behavior.  Make no assumption about the result.
      return MakeBound(kNegInf, kPosInf);
    }
    int64_t val = static_cast<int64_t>(std::ceil(std::log2(value)));
    return MakeBound(val, val);
  }

  Entry arg_bounds = VisitExpr(arg);
  if (arg_bounds.min_value <= 0) {
    // The argument may be zero or negative (this includes an unbounded
    // minimum).  The value produced at runtime is then target
    // dependent, so avoid making any assumptions about the result.
    return MakeBound(kNegInf, kPosInf);
  }

  // ceil(log2(x)) is monotonic for x > 0, so the bounds of the argument
  // map directly onto the bounds of the result.
  int64_t lower = static_cast<int64_t>(std::ceil(std::log2(arg_bounds.min_value)));
  // An unbounded maximum must stay unbounded instead of being turned
  // into log2 of the sentinel value.
  int64_t upper = kPosInf;
  if (arg_bounds.max_value != kPosInf) {
    upper = static_cast<int64_t>(std::ceil(std::log2(arg_bounds.max_value)));
  }
  return MakeBound(lower, upper);
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {
Expand Down
21 changes: 21 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1640,13 +1640,34 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// the operator overload will eagerly constant fold.
return op->args[0] << op->args[1];
}
} else if (op->op.same_as(Op::Get("tir.ceil"))) {
PrimExpr ceil_arg = op->args[0];
if (auto arg_int = op->args[0].as<IntImmNode>()) {
return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value));
} else if (auto arg_float = ceil_arg.as<FloatImmNode>()) {
return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value)));
} else if (auto arg_call = ceil_arg.as<CallNode>()) {
// ceil(log2(cast(n,"float64"))) is used as the implementation of
// topi.math.ceil_log2, and appears in iteration bounds.
if (arg_call->op.same_as(Op::Get("tir.log2"))) {
PrimExpr log_arg = arg_call->args[0];
if (auto as_float = log_arg.as<FloatImmNode>()) {
// ceil(log2(n)) can be simplified, and should produce the
// same integer result regardless of the target's rounding
// conventions.
return FloatImm(op->dtype, std::ceil(std::log2(as_float->value)));
}
}
}
}

if (op->op.same_as(tir::builtin::likely())) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (auto match = TryMatchLiteralConstraint(op->args[0])) {
return match.value();
}
}

return ret;
}

Expand Down
107 changes: 107 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,5 +391,112 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
A[i, j] = 2


class TestCeilLog2Int(BaseBeforeAfter):
"""Simplify expressions resulting from topi.math.ceil_log2"""

# The log2 argument is the constant 14.  Since 2**3 < 14 <= 2**4,
# ceil(log2(14)) == 4, so the whole expression should fold to 4.
@T.prim_func
def before(A: T.Buffer[1, "int32"]):
A[0] = T.cast(
T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32"
)

@T.prim_func
def expected(A: T.Buffer[1, "int32"]):
A[0] = 4


class TestLeftCeilLog2LowerBound(BaseBeforeAfter):
"""Integer bounds are propagated through topi.math.ceil_log2"""

# With i in 0..15, the log2 argument i + 1024 + 1 lies in [1025, 1040].
# Since 2**10 < 1025 and 1040 <= 2**11, ceil(log2(...)) is 11 for every
# iteration, so the `x == 11` condition is expected to be removed.
@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
x = T.cast(
T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"), dtype="float64"), dtype="float64"),
dtype="int32",
)
if x == 11:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftLowerBound(BaseBeforeAfter):
"""Integer bounds are propagated through left shift

min(1 << i) = 1 << min(i)
= 1 << 0
= 1
"""

# The condition `(1 << i) >= 1` holds for every i in 0..15, so the
# conditional is expected to be simplified away.
@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(1, i, dtype="int32") >= 1:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftUpperBound(BaseBeforeAfter):
"""Integer bounds are propagated through left shift

max(31 << i) = 31 << max(i)
= 31 << 15
= 1015808
"""

# 31 * 2**15 == 1015808, so `(31 << i) <= 1015808` holds for every
# i in 0..15 and the conditional is expected to be simplified away.
@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(31, i, dtype="int32") <= 1015808:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftOfNegativeValue(BaseBeforeAfter):
"""No const int bounds of left shift of negative value.

This is target dependent, and does not currently have a specified
behavior in TIR.  For example, in CodeGenC, this generates C code
with undefined behavior.
"""

# The shifted value is -i, which can be negative.
@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if -64 <= T.shift_left(-i, 4, dtype="int32"):
A[i] = 0.0

# The expected output is the input itself: the conditional must not be
# simplified away.
expected = before


class TestLeftShiftByNegativeValue(BaseBeforeAfter):
"""No const int bounds of left shift by negative bit count.

This is target dependent, and does not currently have a specified
behavior in TIR.  For example, in CodeGenC, this generates C code
with undefined behavior.
"""

# The shift amount is -i, which can be negative.
@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(16, -i, dtype="int32") <= 16:
A[i] = 0.0

# The expected output is the input itself: the conditional must not be
# simplified away.
expected = before


if __name__ == "__main__":
tvm.testing.main()