-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Arith] Simplification of ceil, log2, and left_shift #11646
Changes from 2 commits
fed5cc6
2cc2661
66687e6
86d8165
1e4e642
58b9b7e
8e112ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1640,13 +1640,27 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { | |
// the operator overload will eagerly constant fold. | ||
return op->args[0] << op->args[1]; | ||
} | ||
} else if (op->op.same_as(Op::Get("tir.ceil"))) { | ||
if (auto as_int = op->args[0].as<IntImmNode>()) { | ||
return cast(op->dtype, IntImm(as_int->dtype, as_int->value)); | ||
} else if (auto as_float = op->args[0].as<FloatImmNode>()) { | ||
return cast(op->dtype, FloatImm(as_float->dtype, std::ceil(as_float->value))); | ||
} | ||
} else if (op->op.same_as(Op::Get("tir.log2"))) { | ||
if (auto as_int = op->args[0].as<IntImmNode>()) { | ||
return cast(op->dtype, FloatImm(as_int->dtype, std::log2(as_int->value))); | ||
} else if (auto as_float = op->args[0].as<FloatImmNode>()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#mathematical-functions-appendix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Because this folding would be applied in the |
||
return cast(op->dtype, FloatImm(as_float->dtype, std::log2(as_float->value))); | ||
} | ||
} | ||
|
||
if (op->op.same_as(tir::builtin::likely())) { | ||
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } | ||
if (auto match = TryMatchLiteralConstraint(op->args[0])) { | ||
return match.value(); | ||
} | ||
} | ||
|
||
return ret; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if a/b has negative
min_value
, does taking the max narrow the bound? shall we returnEverything
if the bound can't be proved?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, this might depend on the target. For the
CodegenC
backend, it doesn't narrow the bounds for any legal use of<<
, because negative arguments are entirely undefined behavior anyways. Taking themax
here is a way to express that same constraint. That is, even if we can't prove that the argument is non-negative, its use in a bitshifting operator provides a constraint.That said, since my primary goal is to improve simplifications from
ceil_log2
, perhaps it would be better to look for aceil(log2(x))
call, and use that directly. The aggressive optimizations that C++ compilers make based on undefined behavior reasoning are controversial for a reason, and I'd like to avoid introducing similar logic in TVM unless required.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given simplification logic like this one are used everywhere, it does merit an extra care. Specialization might make sense to be on the safe side.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I had assumed that this was safe, because a similar compile-time check is used for the constant folding in
operator<<
, but in that case the left shift would be entirely removed. By removing a conditional, but still allowing the left shift to remain, we're asserting that we are handling the left shift identically to how it will be at runtime.I've updated the handling of left shift to return
Everything
in the case of potentially negative arguments, and added specific handling forceil(log2(x))
.