Skip to content

Commit

Permalink
ConstIntBound with memorization
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Apr 2, 2020
1 parent 458f0b3 commit e960449
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
8 changes: 8 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ class ConstIntBoundAnalyzer {
*/
ConstIntBound operator()(const PrimExpr& expr);

/*!
* \brief analyze the expr with the intermediate memorized to avoid redundant computation
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);

/*!
* \brief Update constant int bound information of var.
*
Expand Down
22 changes: 22 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ class ConstIntBoundAnalyzer::Impl :
res = Intersect(res, info.bound);
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}

Expand Down Expand Up @@ -349,10 +360,13 @@ class ConstIntBoundAnalyzer::Impl :
}

private:
friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// look up table for memorization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
Expand Down Expand Up @@ -545,6 +559,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
return ConstIntBound(ret.min_value, ret.max_value);
}

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
return ConstIntBound(ret.min_value, ret.max_value);
}

void ConstIntBoundAnalyzer::Update(const Var& var,
const ConstIntBound& info,
bool override) {
Expand Down
9 changes: 7 additions & 2 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ class DataTypeVisitor final : public StmtExprVisitor {
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
ConstIntBound bound = analyzer_.const_int_bound(e);
const PrimExprNode* op = e.as<PrimExprNode>();
if (bound_.find(op) == bound_.end()) {
analyzer_.const_int_bound(e, &bound_);
}
ConstIntBound bound = bound_[op];
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
Expand Down Expand Up @@ -183,6 +187,8 @@ class DataTypeVisitor final : public StmtExprVisitor {
int target_bits_;
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
// the memorized bound generated by ConstIntBoundAnalyzer
std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
};

class DataTypeRewriter : public StmtExprMutator {
Expand Down Expand Up @@ -394,7 +400,6 @@ namespace transform {
Pass NarrowDataType() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
// TODO(@hzfan): should Target be Attr here, with target_bits inferred from it?
IntImm target_bits = f->GetAttr<IntImm>("target_bits");
CHECK(target_bits.defined())
<< "NarrowDataType: Require the target_bits";
Expand Down

0 comments on commit e960449

Please sign in to comment.