Skip to content

Commit

Permalink
handle likely in IRMutatorWithAnalyzer (#5665)
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck authored May 25, 2020
1 parent 0833b07 commit 4c976a6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,25 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {

Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr real_condition = condition;
if (auto call = condition.as<CallNode>()) {
if (call->is_intrinsic(CallNode::likely)) {
real_condition = call->args[0];
}
}

Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, condition);
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition)));
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(NotNode::make(real_condition)));
else_case = this->VisitStmt(op->else_case);
}
if (is_one(condition)) return then_case;
if (is_zero(condition)) {
if (is_one(real_condition)) return then_case;
if (is_zero(real_condition)) {
if (else_case.defined()) {
return else_case;
}
Expand Down
4 changes: 3 additions & 1 deletion src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {

std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
size_t old_literal_size = literal_constraints_.size();
literal_constraints_.push_back(constraint);
// we will compare the already simplified result with the constraint,
// so simplify the constarint as well
literal_constraints_.push_back(operator()(constraint));
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
CHECK_EQ(literal_constraints_.size(), new_literal_size);
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,26 @@ def test_thread_extent_simplify():
assert isinstance(body.body.body.body, tvm.tir.Store)


def test_if_likely():
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
C = ib.pointer("float32", name="C")
n = te.size_var("n")
tx = te.thread_axis("threadIdx.x")
ty = te.thread_axis("threadIdx.y")
ib.scope_attr(tx, "thread_extent", 32)
ib.scope_attr(ty, "thread_extent", 32)
with ib.if_scope(ib.likely(tx * 32 + ty < n)):
with ib.if_scope(ib.likely(tx * 32 + ty < n)):
A[tx] = C[tx * 32 + ty]
body = ib.get()
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A, C, n], body))
body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body.body, tvm.tir.IfThenElse)
assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse)


def test_basic_likely_elimination():
n = te.size_var('n')
X = te.placeholder(shape=(n,), name="x")
Expand Down Expand Up @@ -110,5 +130,6 @@ def sls(n, d):
if __name__ == "__main__":
test_stmt_simplify()
test_thread_extent_simplify()
test_if_likely()
test_basic_likely_elimination()
test_complex_likely_elimination()

0 comments on commit 4c976a6

Please sign in to comment.