Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Constraint-aware ConstIntBound, Enhance CanonicalSimplify #3132

Merged
merged 2 commits into from
May 4, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/arithmetic/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -453,6 +453,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SplitExprNode>()) {
return GetRef<SplitExpr>(op);
}
if (const auto* op = expr.as<SumExprNode>()) {
if (op->base == 0 && op->args.size() == 1) return op->args[0];
}
if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
expr = op->Normalize();
}
Expand Down Expand Up @@ -764,6 +767,16 @@ Mutate_(const Mod* op, const Expr& self) {
}
}
}
// Simplify the offset constant if necessary.
// (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
auto cbound = parent_->const_int_bound(Normalize(a));
int64_t new_base = psum->base % cval;
if (cbound->min_value >= 0 &&
cbound->min_value - psum->base + new_base >= 0) {
SumExpr sum_expr(std::move(a.node_));
sum_expr.CopyOnWrite()->base = new_base;
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval);
}
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
Expand Down
77 changes: 74 additions & 3 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -25,6 +25,7 @@
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include "int_op_overflow.h"
#include "pattern_match.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -65,6 +66,19 @@ struct ConstIntBoundAnalyzer::Entry {
class ConstIntBoundAnalyzer::Impl :
public ExprFunctor<ConstIntBoundAnalyzer::Entry(const Expr&)> {
public:
/*! \brief additional bound info about expr \in bound */
struct BoundInfo {
/*! \brief The expr */
Expr expr;
/*! \brief The additional bound */
Entry bound;

BoundInfo() {}
BoundInfo(Expr expr, Entry bound)
: expr(expr), bound(bound) {
}
};

void Bind(const Var& var, const Range& range) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Expand Down Expand Up @@ -99,6 +113,18 @@ class ConstIntBoundAnalyzer::Impl :
static_cast<const ir::BaseExprNode*>(op)->type);
}

Entry VisitExpr(const Expr& expr) final {
Entry res = ExprFunctor::VisitExpr(expr);
// a linear search over additional info
// assume we won't have a lot of conditions
for (const BoundInfo& info : additional_info_) {
if (ir::Equal(expr, info.expr)) {
res = Intersect(res, info.bound);
}
}
return res;
}

Entry VisitExpr_(const Cast* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->type);
Expand Down Expand Up @@ -243,9 +269,24 @@ class ConstIntBoundAnalyzer::Impl :
}
}

std::function<void()> EnterConstraint(const Expr& constraint) {
std::vector<BoundInfo> info = DetectBoundInfo(constraint);
if (info.size() == 0) return nullptr;
size_t old_size = additional_info_.size();
additional_info_.insert(additional_info_.end(), info.begin(), info.end());
size_t new_size = old_size + info.size();
auto frecover = [old_size, new_size, this]() {
CHECK_EQ(additional_info_.size(), new_size);
additional_info_.resize(old_size);
};
return frecover;
}

private:
// internal variable map
std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
Expand Down Expand Up @@ -387,6 +428,36 @@ class ConstIntBoundAnalyzer::Impl :
}
return ret;
}

/*!
* \brief Detect additional constant bound from cond, if any
* \param cond The constraint condition.
* \return List of detected bounds.
*/
static std::vector<BoundInfo> DetectBoundInfo(const Expr& cond) {
PVar<Expr> x, y;
PVar<Integer> c;
// NOTE: canonical form always use <= or <
if ((c <= x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
}
if ((c < x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))};
}
if ((x <= c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))};
}
if ((x < c).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))};
}
if ((x && y).Match(cond)) {
auto ret1 = DetectBoundInfo(x.Eval());
auto ret2 = DetectBoundInfo(y.Eval());
ret1.insert(ret1.end(), ret2.begin(), ret2.end());
return ret1;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need the x == c (and c == x) case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not very sure if this is a case that we will need as frequently, so I decided to leave it out after some thought. We can revisit later. This might be better suited for a condition in a simplifier. Most conditions like x % 3== 1 is detected by modular analysis.

return {};
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
Expand All @@ -405,7 +476,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
}

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) {
return nullptr;
return impl_->EnterConstraint(constraint);
}

ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent)
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
60 changes: 51 additions & 9 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -1197,14 +1197,32 @@ Mutate_(const Or* op, const Expr& self) {

Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Select>();
if (is_zero(op->condition)) {
return op->false_value;
Expr cond = Mutate(op->condition);
Expr true_value, false_value;
{
ConstraintContext constraint(parent_, cond);
true_value = Mutate(op->true_value);
}
{
ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->false_value);
}
if (is_one(op->condition)) {
return op->true_value;
if (is_zero(cond)) {
return false_value;
}
if (is_one(cond)) {
return true_value;
}
// normal path
Expr ret;
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
ret = self;
} else {
ret = Select::make(cond, true_value, false_value);
}
op = ret.as<Select>();
// Pattern var to match any expression
PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y);
Expand All @@ -1213,7 +1231,31 @@ Mutate_(const Select* op, const Expr& self) {

Expr RewriteSimplifier::Impl::
Mutate_(const Call* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
// add condition context to if_then_else
Expr ret;
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = Mutate(op->args[0]);
Expr true_value, false_value;
{
ConstraintContext constraint(parent_, cond);
true_value = Mutate(op->args[1]);
}
{
ConstraintContext constraint(parent_, Mutate(Not::make(cond)));
false_value = Mutate(op->args[2]);
}
tqchen marked this conversation as resolved.
Show resolved Hide resolved
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
ret = self;
} else {
ret = Call::make(op->type, op->name,
{cond, true_value, false_value},
op->call_type);
}
} else {
ret = IRMutator::Mutate_(op, self);
}
op = ret.as<Call>();
if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
return op->args[0];
Expand Down
24 changes: 23 additions & 1 deletion tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self):

def verify(self, data, expected):
res = self.analyzer.canonical_simplify(data)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
assert tvm.ir_pass.Equal(res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected)


def test_mul_sum_simplify():
Expand Down Expand Up @@ -157,7 +157,29 @@ def test_reduce_simplify():
ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))


def test_simplify_if_then_else():
ck = CanonicalChecker()
x = tvm.var("x")
y = tvm.var("y")
# simplification that takes condition into account.
res = tvm.if_then_else((x * 4 + y) >= 466036,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
x), y)
expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
(((x*4) + y) - 4) % 16,
x), y)
ck.verify(res, expected)
# can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
ck.verify(res, ck.analyzer.canonical_simplify(expected))


if __name__ == "__main__":
test_simplify_if_then_else()
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()
Expand Down