Skip to content

Commit

Permalink
[ARITH] DeduceBound (#40)
Browse files Browse the repository at this point in the history
* [PYTHON/API] Add compare and logic build-in op for Expr

* remove 'and', 'or'

* add deducer

* [WIP] bound_deducer.cc

* move IntervalSet and StrideSet into int_set_internal.h

* add multiple failure for VariablePathFinder, add EvalSign

* consider round in deduce, add success flag

* remove Visit_(Div)

* add comment, update HalideIR

* expose intset to python

* check the sign of every expr

* set return type as ExprSignType

* fine tune

* add min & max python api for interval set

* support for conditional expr

* refactor test

* add checker for BoundDeducer

* add python check test

* fix

* fix

* change range to interval; remove converter

* remove converter declaration

* remove int_set_internal.h
  • Loading branch information
Ziheng Jiang authored and tqchen committed Feb 17, 2017
1 parent d114dfc commit 5198c10
Show file tree
Hide file tree
Showing 10 changed files with 555 additions and 48 deletions.
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 642ae5 to e68ae6
1 change: 1 addition & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._ctypes._node import register_node

from . import tensor
from . import arith
from . import expr
from . import stmt
from . import make
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ctypes/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def _init_api_functions(root_namespace):
module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_arith_": sys.modules["%s.arith" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# pylint: disable=protected-access, no-member
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal

@register_node
class IntSet(NodeBase):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _api_internal._IntSetIsNothing(self)

def is_everything(self):
"""Whether the set represent everything"""
return _api_internal._IntSetIsEverything(self)

@register_node
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)

def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)

@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass

50 changes: 50 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to arith
* \file api_arith.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"

namespace tvm {
namespace arith {

TVM_REGISTER_API(_arith_intset_single_point)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});

TVM_REGISTER_API(_arith_intset_interval)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});

TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
});

TVM_REGISTER_API(_IntervalSetGetMin)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});

TVM_REGISTER_API(_IntervalSetGetMax)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});

TVM_REGISTER_API(_IntSetIsNothing)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});

TVM_REGISTER_API(_IntSetIsEverything)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});

} // namespace arith
} // namespace tvm
229 changes: 229 additions & 0 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include <unordered_map>
#include "./int_set.h"

namespace tvm {
namespace arith {

using namespace ir;
using Halide::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
public:
explicit VariablePathFinder(Var target) : target_(target) {}

void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());

if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node);
if (!found_) path_.pop_back();
}

std::vector<const Node*> path_;

private:
bool found_{false};
Var target_;
std::unordered_set<const Node*> visited_;
};

// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
return v.path_;
}

class BoundDeduceIntputChecker;

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Var target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_(target), expr_(expr), dom_map_(dom_map) {}

bool Init();
void Deduce();

void Visit(const NodeRef& e) final {
if (!success) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
return;
}
}

void Visit_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
} else {
result -= op->a;
result = - result;
is_greater = !is_greater;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;

SignType sign;
if (operand.type().is_uint()) {
sign = kPositive;
} else {
sign = expr_map_[operand].sign_type();
}

if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
return;
}

// always use relax bound
result = result / operand + (is_greater ? 1 : -1);
Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};
bool is_equal{true};
bool success{true};

private:
Var target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_;
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
};

class BoundDeduceInputChecker: public IRVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
Visit(deducer_->expr_);
return target_count == 1;
}

void Visit(const NodeRef& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
}

private:
BoundDeducer* deducer_;
size_t target_count{0};
};

bool BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;

if (const LT* op = expr_.as<LT>()) {
is_greater = false;
is_equal = false;
expr_ = op->a;
result = op->b;
} else if (const LE* op = expr_.as<LE>()) {
is_greater = false;
is_equal = true;
expr_ = op->a;
result = op->b;
} else if (const GT* op = expr_.as<GT>()) {
is_greater = true;
is_equal = false;
expr_ = op->a;
result = op->b;
} else if (const GE* op = expr_.as<GE>()) {
is_greater = true;
is_equal = true;
expr_ = op->a;
result = op->b;
} else {
success = false;
}
return success;
}

void BoundDeducer::Deduce() {
Init();
if (!success) return;

// get the path
path_ = GetPath(target_, expr_);
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_);

Visit(expr_);
}

// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Var v, Expr e,
const Map<Var, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first.get()] = kv.second;
}
BoundDeducer d(v, e, dmap);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
if (d.is_greater) {
min = d.is_equal ? d.result : d.result + 1;
} else {
max = d.is_equal ? d.result : d.result - 1;
}
return IntSet::interval(min, max);
}

} // namespace arith
} // namespace tvm
Loading

0 comments on commit 5198c10

Please sign in to comment.