Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache PrimExpr instead of raw pointers in bound analyzer #5533

Merged
merged 1 commit into from
May 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ConstIntBound : public ObjectRef {
*/
class ConstIntBoundAnalyzer {
public:
using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectHash, ObjectEqual>;
/*!
* \brief analyze the expr
* \param expr The expression of interest.
Expand All @@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer {
* \param bound The lookup table to store the intermediate results
* \return the result of the analysis.
*/
TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);

/*!
* \brief Update constant int bound information of var.
Expand Down
11 changes: 5 additions & 6 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,16 @@ class ConstIntBoundAnalyzer::Impl :
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
auto val = bound_->find(expr);
if (val != bound_->end()) {
auto everything = Everything(op->dtype);
auto everything = Everything(expr->dtype);
CHECK(
(val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
(val->second->min_value == everything.min_value &&
val->second->max_value == everything.max_value))
<< "Detected bound for " << expr << "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
(*bound_)[expr] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}
Expand Down Expand Up @@ -369,7 +368,7 @@ class ConstIntBoundAnalyzer::Impl :
// additional bound info
std::vector<BoundInfo> additional_info_;
// look up table for memorization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
BoundMapType* bound_{nullptr};
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
Expand Down Expand Up @@ -563,7 +562,7 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
}

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
BoundMapType* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
Expand Down
7 changes: 3 additions & 4 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,10 @@ class DataTypeVisitor final : public StmtExprVisitor {
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
const PrimExprNode* op = e.as<PrimExprNode>();
if (bound_.find(op) == bound_.end()) {
if (bound_.find(e) == bound_.end()) {
analyzer_.const_int_bound(e, &bound_);
}
ConstIntBound bound = bound_[op];
ConstIntBound bound = bound_[e];
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
Expand Down Expand Up @@ -187,7 +186,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
// the memorized bound generated by ConstIntBoundAnalyzer
std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
arith::ConstIntBoundAnalyzer::BoundMapType bound_;
};

class DataTypeRewriter : public StmtExprMutator {
Expand Down