-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [PYTHON/API] Add compare and logic build-in op for Expr * remove 'and', 'or' * add deducer * [WIP] bound_deducer.cc * move IntervalSet and StrideSet into int_set_internal.h * add multiple failure for VariablePathFinder, add EvalSign * consider round in deduce, add success flag * remove Visit_(Div) * add comment, update HalideIR * expose intset to python * check the sign of every expr * set return type as ExprSignType * fine tune * add min & max python api for interval set * support for conditional expr * refactor test * add checker for BoundDeducer * add python check test * fix * fix * change range to interval; remove converter * remove converter declaration * remove int_set_internal.h
- Loading branch information
Showing
10 changed files
with
555 additions
and
48 deletions.
There are no files selected for viewing
Submodule HalideIR
updated
from 642ae5 to e68ae6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# pylint: disable=protected-access, no-member | ||
"""Arithmetic data structure and utility""" | ||
from __future__ import absolute_import as _abs | ||
from ._ctypes._node import NodeBase, register_node | ||
from . import _api_internal | ||
|
||
@register_node | ||
class IntSet(NodeBase): | ||
"""Represent a set of integer in one dimension.""" | ||
def is_nothing(self): | ||
"""Whether the set represent nothing""" | ||
return _api_internal._IntSetIsNothing(self) | ||
|
||
def is_everything(self): | ||
"""Whether the set represent everything""" | ||
return _api_internal._IntSetIsEverything(self) | ||
|
||
@register_node | ||
class IntervalSet(IntSet): | ||
"""Represent set of continuous interval""" | ||
def min(self): | ||
"""get the minimum value""" | ||
return _api_internal._IntervalSetGetMin(self) | ||
|
||
def max(self): | ||
"""get the maximum value""" | ||
return _api_internal._IntervalSetGetMax(self) | ||
|
||
@register_node | ||
class StrideSet(IntSet): | ||
"""Represent set of strided integers""" | ||
pass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* Implementation of API functions related to arith | ||
* \file api_arith.cc | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir.h> | ||
#include <tvm/api_registry.h> | ||
#include "../arithmetic/int_set.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
TVM_REGISTER_API(_arith_intset_single_point) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntSet::single_point(args[0]); | ||
}); | ||
|
||
TVM_REGISTER_API(_arith_intset_interval) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntSet::interval(args[0], args[1]); | ||
}); | ||
|
||
TVM_REGISTER_API(_arith_DeduceBound) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = DeduceBound(args[0], args[1], args[2]); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntervalSetGetMin) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().min(); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntervalSetGetMax) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().max(); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntSetIsNothing) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().is_nothing(); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntSetIsEverything) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().is_everything(); | ||
}); | ||
|
||
} // namespace arith | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file bound_deducer.cc | ||
* \brief Utility to deduce bound of expression | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir_pass.h> | ||
#include <tvm/ir_visitor.h> | ||
#include <tvm/api_registry.h> | ||
#include <unordered_set> | ||
#include <unordered_map> | ||
#include "./int_set.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
using namespace ir; | ||
using Halide::Internal::Interval; | ||
|
||
// a visitor to find the path to the target variable | ||
// from a expression. | ||
class VariablePathFinder: public IRVisitor { | ||
public: | ||
explicit VariablePathFinder(Var target) : target_(target) {} | ||
|
||
void Visit(const NodeRef& node) final { | ||
if (visited_.count(node.get()) != 0) return; | ||
visited_.insert(node.get()); | ||
|
||
if (!found_) path_.push_back(node.get()); | ||
if (node.same_as(target_)) found_ = true; | ||
IRVisitor::Visit(node); | ||
if (!found_) path_.pop_back(); | ||
} | ||
|
||
std::vector<const Node*> path_; | ||
|
||
private: | ||
bool found_{false}; | ||
Var target_; | ||
std::unordered_set<const Node*> visited_; | ||
}; | ||
|
||
// get the path to the variable, | ||
// return empty vector to represent failure | ||
std::vector<const Node*> GetPath(Var target, Expr expr) { | ||
VariablePathFinder v(target); | ||
v.Visit(expr); | ||
return v.path_; | ||
} | ||
|
||
class BoundDeduceIntputChecker; | ||
|
||
// a visitor to deduce the bound of a variable from a expression | ||
class BoundDeducer: public IRVisitor { | ||
public: | ||
friend class BoundDeduceInputChecker; | ||
friend class Converter; | ||
BoundDeducer(Var target, Expr expr, | ||
const std::unordered_map<const Variable*, IntSet>& dom_map) | ||
: target_(target), expr_(expr), dom_map_(dom_map) {} | ||
|
||
bool Init(); | ||
void Deduce(); | ||
|
||
void Visit(const NodeRef& e) final { | ||
if (!success) return; | ||
if (e.get() == path_[iter_++]) { | ||
IRVisitor::Visit(e); | ||
} else { | ||
success = false; | ||
return; | ||
} | ||
} | ||
|
||
void Visit_(const LT* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const LE* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const GT* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const GE* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const Add* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
result -= left ? op->b : op->a; | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Sub* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
if (left) { | ||
result += op->b; | ||
} else { | ||
result -= op->a; | ||
result = - result; | ||
is_greater = !is_greater; | ||
} | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Mul* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
Expr operand = left ? op->b : op->a; | ||
|
||
SignType sign; | ||
if (operand.type().is_uint()) { | ||
sign = kPositive; | ||
} else { | ||
sign = expr_map_[operand].sign_type(); | ||
} | ||
|
||
if (sign == SignType::kNegative) { | ||
is_greater = !is_greater; | ||
} else if (sign == SignType::kUnknown) { | ||
// unable to get the sign of operand | ||
success = false; | ||
return; | ||
} | ||
|
||
// always use relax bound | ||
result = result / operand + (is_greater ? 1 : -1); | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
Expr result; | ||
bool is_greater{true}; | ||
bool is_equal{true}; | ||
bool success{true}; | ||
|
||
private: | ||
Var target_; | ||
Expr expr_; | ||
const std::unordered_map<const Variable*, IntSet>& dom_map_; | ||
ExprIntSetMap expr_map_; | ||
std::vector<const Node*> path_; | ||
size_t iter_{0}; | ||
}; | ||
|
||
class BoundDeduceInputChecker: public IRVisitor { | ||
public: | ||
bool Check(BoundDeducer* deducer) { | ||
deducer_ = deducer; | ||
Visit(deducer_->expr_); | ||
return target_count == 1; | ||
} | ||
|
||
void Visit(const NodeRef& e) final { | ||
if (e.same_as(deducer_->target_)) ++target_count; | ||
IRVisitor::Visit(e); | ||
} | ||
|
||
private: | ||
BoundDeducer* deducer_; | ||
size_t target_count{0}; | ||
}; | ||
|
||
bool BoundDeducer::Init() { | ||
BoundDeduceInputChecker checker; | ||
if (!checker.Check(this)) success = false; | ||
|
||
if (const LT* op = expr_.as<LT>()) { | ||
is_greater = false; | ||
is_equal = false; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else if (const LE* op = expr_.as<LE>()) { | ||
is_greater = false; | ||
is_equal = true; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else if (const GT* op = expr_.as<GT>()) { | ||
is_greater = true; | ||
is_equal = false; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else if (const GE* op = expr_.as<GE>()) { | ||
is_greater = true; | ||
is_equal = true; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else { | ||
success = false; | ||
} | ||
return success; | ||
} | ||
|
||
void BoundDeducer::Deduce() { | ||
Init(); | ||
if (!success) return; | ||
|
||
// get the path | ||
path_ = GetPath(target_, expr_); | ||
// get the sign of every subexpr | ||
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); | ||
|
||
Visit(expr_); | ||
} | ||
|
||
// assuming e >= 0, deduce the bound of variable from it. | ||
// return empty set to represent deduce failure. | ||
IntSet DeduceBound(Var v, Expr e, | ||
const Map<Var, IntSet>& dom_map) { | ||
std::unordered_map<const Variable*, IntSet> dmap; | ||
for (auto kv : dom_map) { | ||
dmap[kv.first.get()] = kv.second; | ||
} | ||
BoundDeducer d(v, e, dmap); | ||
d.Deduce(); | ||
if (!d.success) return IntSet::nothing(); | ||
Expr min = Interval::neg_inf, max = Interval::pos_inf; | ||
if (d.is_greater) { | ||
min = d.is_equal ? d.result : d.result + 1; | ||
} else { | ||
max = d.is_equal ? d.result : d.result - 1; | ||
} | ||
return IntSet::interval(min, max); | ||
} | ||
|
||
} // namespace arith | ||
} // namespace tvm |
Oops, something went wrong.