Skip to content

Commit

Permalink
Migrate simplifier to new infra. (apache#3368)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and Wei Chen committed Jul 11, 2019
1 parent fe3f274 commit f7eba84
Show file tree
Hide file tree
Showing 19 changed files with 175 additions and 154 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
Expand Down
9 changes: 6 additions & 3 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
Expand All @@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down
1 change: 1 addition & 0 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
Expr Analyzer::Simplify(const Expr& expr) {
if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
if (is_const(res)) return res;
res = this->canonical_simplify(res);
return res;
}
Expand Down
145 changes: 84 additions & 61 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
void Deduce();

void Visit(const NodeRef& e) final {
if (!success) return;
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
success_ = false;
return;
}
}
Expand All @@ -111,62 +111,84 @@ class BoundDeducer: public IRVisitor {

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
result_ += op->b;
} else {
result -= op->a;
result = - result;
is_greater = !is_greater;
result_ -= op->a;
result_ = - result_;
is_greater_ = !is_greater_;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;

SignType sign;
SignType sign_operand;
if (operand.type().is_uint()) {
sign = kPositive;
sign_operand = kPositive;
} else {
sign = expr_map_[operand].sign_type();
sign_operand = expr_map_[operand].sign_type();
}

if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
if (sign_operand == SignType::kNegative) {
is_greater_ = !is_greater_;
} else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
success_ = false;
return;
}

// always use relax bound
bool divided = can_prove(result % operand == 0);
result = result / operand;
// since system will round down when not divided
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
// eg. a <= 2/4 -> a <= 0
// eg. a <= 0/4 -> a <= 0
// so just fix for not divided and is_greater
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if (is_greater && !divided) {
result += 1;
bool divided = analyzer_.CanProve(result_ % operand == 0);

result_ = result_ / operand;

if (!divided) {
// Handle non-divisible case
// NOTE: this accounts for truc div behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();

if (is_greater_) {
result_ += 1;
} else {
// NOTE: this is a bit sutble hack.
//
// condition:
// - x * operand <= result
// - operand > 0
// - x >= 0
//
// Then it is fine to deduce that x <= result / operand.
// - if result > 0, this division round down
// - if result < 0, (result / operand) rounds up and may violate the constraint
// however, given that x is always non-negative,
// it is fine to have this relaxed bound, given that the user of deduce bound
// will respect the bound of x
//
// TODO(tvm-team): think about a better API to incorporate constraint of x.
// e.g. specify an interval of x and return a bound
// that is in the interval and satisfies the condition.
if (target_is_non_neg && sign_operand == kPositive) {
// do nothing
} else {
result_ -= 1;
}
}
}

Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};
bool success{true};
Expr result_;
bool is_greater_{true};
bool success_{true};

private:
void Init();
Expand All @@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
// internal analzyer
Analyzer analyzer_;
};

class BoundDeduceInputChecker: public IRVisitor {
Expand All @@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {

void BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;
if (!checker.Check(this)) success_ = false;
Transform();
}

Expand All @@ -211,93 +235,92 @@ void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
is_greater = true;
is_greater_ = true;
expr_ = op->b;
result = op->a + 1;
result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
is_greater = false;
is_greater_ = false;
expr_ = op->a;
result = op->b - 1;
result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
is_greater = true;
is_greater_ = true;
expr_ = op->b;
result = op->a;
result_ = op->a;
} else {
is_greater = false;
is_greater_ = false;
expr_ = op->a;
result = op->b;
result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
is_greater = false;
is_greater_ = false;
expr_ = op->b;
result = op->a - 1;
result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
is_greater = true;
is_greater_ = true;
expr_ = op->a;
result = op->b + 1;
result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
is_greater = false;
is_greater_ = false;
expr_ = op->b;
result = op->a;
result_ = op->a;
} else {
is_greater = true;
is_greater_ = true;
expr_ = op->a;
result = op->b;
result_ = op->b;
}
} else {
success = false;
success_ = false;
}
}

void BoundDeducer::Deduce() {
Init();
if (!success) return;
if (!success_) return;
Relax();
if (!success) return;
if (!success_) return;
// get the path
path_ = GetPath(target_, expr_);
if (!path_.size()) {
success = false;
success_ = false;
return;
}

expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);
}

void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result, relax_map_);
IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) {
success = false;
success_ = false;
return;
}
expr_ = is_greater ? a.min() : a.max();
result = is_greater ? b.max() : b.min();
expr_ = is_greater_ ? a.min() : a.max();
result_ = is_greater_ ? b.max() : b.min();
}

IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
if (d.is_greater_) {
min = d.result_;
} else {
max = d.result;
max = d.result_;
}
return IntSet::interval(min, max);
}
Expand Down
7 changes: 4 additions & 3 deletions src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
if (pa && pb) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
Expand Down
1 change: 0 additions & 1 deletion src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));


TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
Expand Down
Loading

0 comments on commit f7eba84

Please sign in to comment.