[Arith] Simplification of ceil, log2, and left_shift #11646

Merged: 7 commits, Jun 21, 2022

Lunderberg
Contributor

The GPU schedule for topi.cuda.schedule_sort includes expressions that use tir.ceil, tir.log2, and tir.shift_left that remain in the TIR representation through the entire lowering flow. (For example, @tir.shift_left(2i64, (i_0 + cast(int64, cast(int32, @tir.ceil(@tir.log2(128f64, dtype=float64), dtype=float64)))), dtype=int64).) However, these expressions are not currently simplified by arith::Analyzer.

This PR introduces two changes, such that these expressions can be simplified.

  • Simplify ceil(log2(constant)) into a constant integer.
  • Identify bounds on x << y, for use in proving conditionals.
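
As a concrete illustration of the first change: the example expression above contains ceil(log2(128)), which is exactly 7, so the shift reduces to 2 << (i_0 + 7). The following standalone C++ sketch checks that arithmetic. It is not TVM code; `FoldCeilLog2` and `ShiftExample` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper, not part of TVM: folds ceil(log2(c)) for a positive
// constant c into an integer, mirroring the shape of the simplification.
inline int32_t FoldCeilLog2(double c) {
  assert(c > 0.0);
  return static_cast<int32_t>(std::ceil(std::log2(c)));
}

// The example expression from the description with ceil(log2(128)) folded,
// i.e. shift_left(2, i_0 + 7). Assumes i_0 is small and non-negative.
inline int64_t ShiftExample(int64_t i_0) {
  return int64_t(2) << (i_0 + FoldCeilLog2(128.0));
}
```

Once the exponent is a known constant, the remaining shift has bounds that follow directly from the bounds on i_0, which is what the second change relies on.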

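[TIR] Simplify expressions using tir.ceil and tir.log2
These expressions are introduced in `topi.math.ceil_log2`, and can
otherwise be propagated through to the generated kernel.

[Arith] Added left shift handling to ConstIntBoundsAnalyzer
Previously, only right shift was handled.  These left shifts are
used in the `cuda.sort` implementation.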
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Jun 9, 2022
Refactored while debugging breakage of tests in
apache#11646.  Submitting as a separate
PR, as it isn't necessary or related to the primary changes in that
PR.
// means we couldn't prove that the inputs were positive.
a.min_value = std::max(int64_t(0), a.min_value);
b.min_value = std::max(int64_t(0), b.min_value);
return BinaryOpBoundary(a, b, InfAwareLeftShift);
Member

If a or b has a negative min_value, does taking the max narrow the bound? Shall we return Everything if the bound can't be proved?

Contributor Author

Hmm, this might depend on the target. For the CodegenC backend, it doesn't narrow the bounds for any legal use of <<, because negative arguments are entirely undefined behavior anyways. Taking the max here is a way to express that same constraint. That is, even if we can't prove that the argument is non-negative, its use in a bitshifting operator provides a constraint.

That said, since my primary goal is to improve simplifications from ceil_log2, perhaps it would be better to look for a ceil(log2(x)) call, and use that directly. The aggressive optimizations that C++ compilers make based on undefined behavior reasoning are controversial for a reason, and I'd like to avoid introducing similar logic in TVM unless required.

Member

@tqchen tqchen Jun 10, 2022

Given that simplification logic like this is used everywhere, it merits extra care. Specialization might make sense, to be on the safe side.

Contributor Author

Good point. I had assumed that this was safe, because a similar compile-time check is used for the constant folding in operator<<, but in that case the left shift would be entirely removed. By removing a conditional, but still allowing the left shift to remain, we're asserting that we are handling the left shift identically to how it will be at runtime.

I've updated the handling of left shift to return Everything in the case of potentially negative arguments, and added specific handling for ceil(log2(x)).
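
The updated left-shift handling described here can be sketched in standalone C++ as follows. This is a simplified model under stated assumptions, not the code merged into TVM: `Bound`, `SaturatingShift`, and `LeftShiftBound` are hypothetical names, and the int64 extremes stand in for infinite bounds.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical, simplified stand-in for a constant integer bound.
struct Bound {
  int64_t min_value;
  int64_t max_value;
};

constexpr int64_t kNegInf = std::numeric_limits<int64_t>::min();
constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();

inline Bound Everything() { return {kNegInf, kPosInf}; }

// x << s for x >= 0 and s >= 0, saturating to kPosInf instead of overflowing.
inline int64_t SaturatingShift(int64_t x, int64_t s) {
  if (x == 0) return 0;
  if (s >= 63 || x > (kPosInf >> s)) return kPosInf;
  return x << s;
}

// Bound on (a << b). If either operand might be negative, no assumption is
// made about the result and the widest possible bound is returned.
inline Bound LeftShiftBound(Bound a, Bound b) {
  assert(a.min_value <= a.max_value && b.min_value <= b.max_value);
  if (a.min_value < 0 || b.min_value < 0) return Everything();
  // For non-negative operands, the shift is monotonic in both arguments.
  return {SaturatingShift(a.min_value, b.min_value),
          SaturatingShift(a.max_value, b.max_value)};
}
```

Returning the widest bound for possibly negative operands means a conditional is only removed when the shift is provably well-defined, which is the conservative choice discussed in this thread.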

} else if (op->op.same_as(Op::Get("tir.log2"))) {
if (auto as_int = op->args[0].as<IntImmNode>()) {
return cast(op->dtype, FloatImm(as_int->dtype, std::log2(as_int->value)));
} else if (auto as_float = op->args[0].as<FloatImmNode>()) {
Contributor

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#mathematical-functions-appendix
Certain standard math functions have different ULP errors on different devices. Could the folding optimization be target-aware?

Contributor Author

Good point. Because this folding would be applied in the tvm.lower path, where the target may be unknown, it would be tricky to make it target-aware. Since my main goal was to simplify integer arguments, where ceil(log2(n)) is used for loop iteration bounds, we could get a similar effect by restricting this folding to ceil(log2(n)) with an integer argument n, which gives the same result regardless of rounding differences.
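
To illustrate why the integer case does not depend on rounding: for an integer n >= 1, ceil(log2(n)) is the smallest k with 2^k >= n, which can be found with integer operations alone. The sketch below is a standalone illustration, not the TVM implementation, and `CeilLog2` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: ceil(log2(n)) for n >= 1 using only integer
// operations, so the result cannot vary with floating-point rounding.
inline int32_t CeilLog2(int64_t n) {
  assert(n >= 1);
  int32_t bits = 0;
  // Find the smallest `bits` such that (1 << bits) >= n. The `bits < 63`
  // guard is checked first, so the shift never reaches the sign bit.
  while (bits < 63 && (int64_t(1) << bits) < n) {
    ++bits;
  }
  return bits;
}
```

Any device or host that evaluates this definition gets the same answer, whereas folding a floating-point log2 alone could differ in the last bit between targets.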

Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Jun 10, 2022
Refactored while debugging breakage of tests in
apache#11646.  Submitting as a separate
PR, as it isn't necessary or related to the primary changes in that
PR.
Updated rewriting of log2(x) to only occur in ceil(log2(x))
Per @wrongtest's request, to avoid rounding differences between
different devices.
masahi pushed a commit that referenced this pull request Jun 13, 2022
Refactored while debugging breakage of tests in
#11646.  Submitting as a separate
PR, as it isn't necessary or related to the primary changes in that
PR.
@vinx13 vinx13 merged commit bd800c9 into apache:main Jun 21, 2022
@Lunderberg Lunderberg deleted the simplify_ceil_log2 branch June 21, 2022 19:32
blackkker pushed a commit to blackkker/tvm that referenced this pull request Jul 7, 2022
* [TIR] Simplify expressions using tir.ceil and tir.log2

These expressions are introduced in `topi.math.ceil_log2`, and can
otherwise be propagated through to the generated kernel.

* [Arith] Added left shift handling to ConstIntBoundsAnalyzer

Previously, only right shift was handled.  These left shifts are
used in the `cuda.sort` implementation.

* Update to avoid left shift of negative numbers

* Updated rewriting of log2(x) to only occur in ceil(log2(x))

Per @wrongtest's request, to avoid rounding differences between
different devices.

* Avoid assumptions made of negative arguments to left-shift

* Recognize bounds of int(ceil(log2(arg)))