Skip to content

Commit

Permalink
Cache PrimExpr instead of raw pointers in bound analyzer (#5533)
Browse files Browse the repository at this point in the history
The objects that the raw pointers point to can be deallocated and new
objects can be allocated at the same address, all while these pointers
are still in the cache. This can lead to unexpected behavior, for
example to calculated bound conflicts with previously cached values.

Caching PrimExpr will prevent the objects from being deallocated while
the cache is active.
  • Loading branch information
Krzysztof Parzyszek authored May 7, 2020
1 parent e40b8bc commit f05b911
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
4 changes: 2 additions & 2 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ConstIntBound : public ObjectRef {
*/
class ConstIntBoundAnalyzer {
public:
using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectHash, ObjectEqual>;
/*!
* \brief analyze the expr
* \param expr The expression of interest.
Expand All @@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer {
* \param bound The lookup table to store the intermediate results
* \return the result of the analysis.
*/
TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);

/*!
* \brief Update constant int bound information of var.
Expand Down
11 changes: 5 additions & 6 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,16 @@ class ConstIntBoundAnalyzer::Impl :
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
auto val = bound_->find(expr);
if (val != bound_->end()) {
auto everything = Everything(op->dtype);
auto everything = Everything(expr->dtype);
CHECK(
(val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
(val->second->min_value == everything.min_value &&
val->second->max_value == everything.max_value))
<< "Detected bound for " << expr << "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
(*bound_)[expr] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}
Expand Down Expand Up @@ -369,7 +368,7 @@ class ConstIntBoundAnalyzer::Impl :
// additional bound info
std::vector<BoundInfo> additional_info_;
// look up table for memorization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
BoundMapType* bound_{nullptr};
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
Expand Down Expand Up @@ -563,7 +562,7 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
}

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
BoundMapType* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
Expand Down
7 changes: 3 additions & 4 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,10 @@ class DataTypeVisitor final : public StmtExprVisitor {
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
const PrimExprNode* op = e.as<PrimExprNode>();
if (bound_.find(op) == bound_.end()) {
if (bound_.find(e) == bound_.end()) {
analyzer_.const_int_bound(e, &bound_);
}
ConstIntBound bound = bound_[op];
ConstIntBound bound = bound_[e];
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
Expand Down Expand Up @@ -187,7 +186,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
// the memorized bound generated by ConstIntBoundAnalyzer
std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
arith::ConstIntBoundAnalyzer::BoundMapType bound_;
};

class DataTypeRewriter : public StmtExprMutator {
Expand Down

0 comments on commit f05b911

Please sign in to comment.