Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Simplification of ceil, log2, and left_shift #11646

Merged
merged 7 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ class ConstIntBoundAnalyzer::Impl

if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tir::builtin::shift_left())) {
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else {
Expand Down Expand Up @@ -341,6 +343,12 @@ class ConstIntBoundAnalyzer::Impl
}
}

Entry VisitLeftShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
return BinaryOpBoundary(a, b, InfAwareLeftShift);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a/b has negative min_value, does taking the max narrow the bound? shall we return Everything if the bound can't be proved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this might depend on the target. For the CodegenC backend, it doesn't narrow the bounds for any legal use of <<, because negative arguments are entirely undefined behavior anyways. Taking the max here is a way to express that same constraint. That is, even if we can't prove that the argument is non-negative, its use in a bitshifting operator provides a constraint.

That said, since my primary goal is to improve simplifications from ceil_log2, perhaps it would be better to look for a ceil(log2(x)) call, and use that directly. The aggressive optimizations that C++ compilers make based on undefined behavior reasoning are controversial for a reason, and I'd like to avoid introducing similar logic in TVM unless required.

Copy link
Member

@tqchen tqchen Jun 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given simplification logic like this one are used everywhere, it does merit an extra care. Specialization might make sense to be on the safe side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I had assumed that this was safe, because a similar compile-time check is used for the constant folding in operator<<, but in that case the left shift would be entirely removed. By removing a conditional, but still allowing the left shift to remain, we're asserting that we are handling the left shift identically to how it will be at runtime.

I've updated the handling of left shift to return Everything in the case of potentially negative arguments, and added specific handling for ceil(log2(x)).

}

Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
Expand Down Expand Up @@ -509,7 +517,33 @@ class ConstIntBoundAnalyzer::Impl
return floordiv(x, y);
}
/*!
* \brief Compute x / y, aware of inf.
* \brief Compute x << y, aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
static int64_t InfAwareLeftShift(int64_t x, int64_t y) {
if (x == kPosInf || x == kNegInf) return x;

// Can be replaced with std::bit_width in C++20
auto bit_width = [](int64_t as_signed) {
uint64_t val = std::abs(as_signed);
int num_bits = 0;
while (val) {
++num_bits;
val >>= 1;
}
return num_bits;
};
int x_bits = bit_width(x);
if (x_bits + y < 64) {
return x << y;
} else {
return kPosInf;
}
}
/*!
* \brief Compute x >> y, aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
Expand Down
14 changes: 14 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1640,13 +1640,27 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// the operator overload will eagerly constant fold.
return op->args[0] << op->args[1];
}
} else if (op->op.same_as(Op::Get("tir.ceil"))) {
if (auto as_int = op->args[0].as<IntImmNode>()) {
return cast(op->dtype, IntImm(as_int->dtype, as_int->value));
} else if (auto as_float = op->args[0].as<FloatImmNode>()) {
return cast(op->dtype, FloatImm(as_float->dtype, std::ceil(as_float->value)));
}
} else if (op->op.same_as(Op::Get("tir.log2"))) {
if (auto as_int = op->args[0].as<IntImmNode>()) {
return cast(op->dtype, FloatImm(as_int->dtype, std::log2(as_int->value)));
} else if (auto as_float = op->args[0].as<FloatImmNode>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#mathematical-functions-appendix
Certain standard math functions have ulp errors on different devices, could the folding optimization be target aware?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Because this folding would be applied in the tvm.lower path where the target may be unknown, it would be tricky to make it be target-aware. Since my main goal was for simplifying integer arguments, where ceil(log2(n)) is used for loop iteration bounds, we could get similar effects by restricting this folding to integer arguments to ceil(log2(n)), which would be the same regardless of rounding differences.

return cast(op->dtype, FloatImm(as_float->dtype, std::log2(as_float->value)));
}
}

if (op->op.same_as(tir::builtin::likely())) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (auto match = TryMatchLiteralConstraint(op->args[0])) {
return match.value();
}
}

return ret;
}

Expand Down
66 changes: 66 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,5 +391,71 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
A[i, j] = 2


class TestCeilLog2Float(BaseBeforeAfter):
"""Simplify expressions resulting from topi.math.ceil_log2"""

@T.prim_func
def before(A: T.Buffer[1, "float32"]):
A[0] = T.ceil(T.log2(14.0, dtype="float32"), dtype="float32")

@T.prim_func
def expected(A: T.Buffer[1, "float32"]):
A[0] = 4.0


class TestCeilLog2Int(BaseBeforeAfter):
"""Simplify expressions resulting from topi.math.ceil_log2"""

@T.prim_func
def before(A: T.Buffer[1, "int32"]):
A[0] = T.cast(
T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32"
)

@T.prim_func
def expected(A: T.Buffer[1, "int32"]):
A[0] = 4


class TestLeftShiftLowerBound(BaseBeforeAfter):
"""Integer bounds are propagated through left shift

min(1 << i) = 1 << min(i)
= 1 << 0
= 1
"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(1, i, dtype="int32") >= 1:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


class TestLeftShiftUpperBound(BaseBeforeAfter):
"""Integer bounds are propagated through left shift

max(31 << i) = 31 << max(i)
= 31 << 15
= 1015808
"""

@T.prim_func
def before(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
if T.shift_left(31, i, dtype="int32") <= 1015808:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0


if __name__ == "__main__":
tvm.testing.main()