-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Arith] Simplification of ceil, log2, and left_shift #11646
Conversation
These expressions are introduced in `topi.math.ceil_log2`, and can otherwise be propagated through to the generated kernel.
Previously, only right shift was handled. These left shifts are used in the `cuda.sort` implementation.
Refactored while debugging breakage of tests in apache#11646. Submitting as a separate PR, as it isn't necessary or related to the primary changes in that PR.
// means we couldn't prove that the inputs were positive. | ||
a.min_value = std::max(int64_t(0), a.min_value); | ||
b.min_value = std::max(int64_t(0), b.min_value); | ||
return BinaryOpBoundary(a, b, InfAwareLeftShift); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if a/b has negative min_value
, does taking the max narrow the bound? shall we return Everything
if the bound can't be proved?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, this might depend on the target. For the CodegenC
backend, it doesn't narrow the bounds for any legal use of <<
, because negative arguments are entirely undefined behavior anyways. Taking the max
here is a way to express that same constraint. That is, even if we can't prove that the argument is non-negative, its use in a bitshifting operator provides a constraint.
That said, since my primary goal is to improve simplifications from ceil_log2
, perhaps it would be better to look for a ceil(log2(x))
call, and use that directly. The aggressive optimizations that C++ compilers make based on undefined behavior reasoning are controversial for a reason, and I'd like to avoid introducing similar logic in TVM unless required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given simplification logic like this one are used everywhere, it does merit an extra care. Specialization might make sense to be on the safe side.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I had assumed that this was safe, because a similar compile-time check is used for the constant folding in operator<<
, but in that case the left shift would be entirely removed. By removing a conditional, but still allowing the left shift to remain, we're asserting that we are handling the left shift identically to how it will be at runtime.
I've updated the handling of left shift to return Everything
in the case of potentially negative arguments, and added specific handling for ceil(log2(x))
.
src/arith/rewrite_simplify.cc
Outdated
} else if (op->op.same_as(Op::Get("tir.log2"))) { | ||
if (auto as_int = op->args[0].as<IntImmNode>()) { | ||
return cast(op->dtype, FloatImm(as_int->dtype, std::log2(as_int->value))); | ||
} else if (auto as_float = op->args[0].as<FloatImmNode>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#mathematical-functions-appendix
Certain standard math functions have ulp errors on different devices, could the folding optimization be target aware?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Because this folding would be applied in the tvm.lower
path where the target may be unknown, it would be tricky to make it be target-aware. Since my main goal was for simplifying integer arguments, where ceil(log2(n))
is used for loop iteration bounds, we could get similar effects by restricting this folding to integer arguments to ceil(log2(n))
, which would be the same regardless of rounding differences.
Refactored while debugging breakage of tests in apache#11646. Submitting as a separate PR, as it isn't necessary or related to the primary changes in that PR.
Per @wrongtest's request, to avoid rounding differences between different devices.
Refactored while debugging breakage of tests in #11646. Submitting as a separate PR, as it isn't necessary or related to the primary changes in that PR.
* [TIR] Simplify expressions using tir.ceil and tir.log2 These expressions are introduced in `topi.math.ceil_log2`, and can otherwise be propagated through to the generated kernel. * [Arith] Added left shift handling to ConstIntBoundsAnalyzer Previously, only right shift was handled. These left shifts are used in the `cuda.sort` implementation. * Update to avoid left shift of negative numbers * Updated rewriting of log2(x) to only occur in ceil(log2(x)) Per @wrongtest's request, to avoid rounding differences between different devices. * Avoid assumptions made of negative arguments to left-shift * Recognize bounds of int(ceil(log2(arg)))
The GPU schedule for
topi.cuda.schedule_sort
includes expressions that usetir.ceil
,tir.log2
, andtir.shift_left
that remain in the TIR representation through the entire lowering flow. (For example,@tir.shift_left(2i64, (i_0 + cast(int64, cast(int32, @tir.ceil(@tir.log2(128f64, dtype=float64), dtype=float64)))), dtype=int64)
.) However, these expressions are not currently simplified byarith::Analyzer
.This PR introduces two changes, such that these expressions can be simplified.
ceil(log2(constant))
into a constant integer.x << y
, for use in proving conditionals.