diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index e09ff1d65a5e..e6f37f453aa5 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -55,17 +55,25 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr real_condition = condition; + if (auto call = condition.as()) { + if (call->is_intrinsic(CallNode::likely)) { + real_condition = call->args[0]; + } + } + Stmt then_case, else_case; { - With ctx(analyzer_, condition); + With ctx(analyzer_, real_condition); then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { - With ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition))); + With ctx(analyzer_, + analyzer_->rewrite_simplify(NotNode::make(real_condition))); else_case = this->VisitStmt(op->else_case); } - if (is_one(condition)) return then_case; - if (is_zero(condition)) { + if (is_one(real_condition)) return then_case; + if (is_zero(real_condition)) { if (else_case.defined()) { return else_case; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 3b8ccfb01a93..223b2e6c5f5a 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -211,7 +211,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { size_t old_literal_size = literal_constraints_.size(); - literal_constraints_.push_back(constraint); + // we will compare the already simplified result with the constraint, + // so simplify the constarint as well + literal_constraints_.push_back(operator()(constraint)); size_t new_literal_size = literal_constraints_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { CHECK_EQ(literal_constraints_.size(), new_literal_size); diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index bf5398245c50..48d0849bd1ee 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -52,6 +52,26 @@ def test_thread_extent_simplify(): assert isinstance(body.body.body.body, tvm.tir.Store) +def test_if_likely(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + n = te.size_var("n") + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + ib.scope_attr(tx, "thread_extent", 32) + ib.scope_attr(ty, "thread_extent", 32) + with ib.if_scope(ib.likely(tx * 32 + ty < n)): + with ib.if_scope(ib.likely(tx * 32 + ty < n)): + A[tx] = C[tx * 32 + ty] + body = ib.get() + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C, n], body)) + body = tvm.tir.transform.Simplify()(mod)["main"].body + assert isinstance(body.body.body, tvm.tir.IfThenElse) + assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse) + + def test_basic_likely_elimination(): n = te.size_var('n') X = te.placeholder(shape=(n,), name="x") @@ -110,5 +130,6 @@ def sls(n, d): if __name__ == "__main__": test_stmt_simplify() test_thread_extent_simplify() + test_if_likely() test_basic_likely_elimination() test_complex_likely_elimination()