Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Mar 20, 2020
1 parent c2eec6d commit 0724b66
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions src/tir/pass/rewrite_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

/*!
* \file rewrite_datatype.cc
* \brief narrow the datatype of indexing vars
*/

#include <tvm/tir/ir_pass.h>
Expand Down Expand Up @@ -57,8 +58,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
vset_.insert(op->loop_var.as<Object>());
vextent_[op->loop_var.as<Object>()] = op->extent.dtype();
vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
return StmtExprVisitor::VisitStmt_(op);
}

Expand All @@ -69,8 +69,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var,
Range::make_by_min_extent(0, op->value));
vset_.insert(iv->var.as<Object>());
vextent_[iv->var.as<Object>()] = op->value.dtype();
vextent_[iv->var.as<VarNode>()] = op->value.dtype();
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
Expand All @@ -81,15 +80,14 @@ class DataTypeVisitor final : public StmtExprVisitor {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
vset_.insert(iv->var.as<Object>());
vextent_[iv->var.as<Object>()] = iv->dom->extent.dtype();
vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
}
// Recursively call simplification when necessary.
StmtExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const VarNode* op) {
if (vset_.find(op) != vset_.end()) {
if (vextent_.find(op) != vextent_.end()) {
int bits = std::min(vextent_[op].bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
Expand Down Expand Up @@ -125,19 +123,17 @@ class DataTypeVisitor final : public StmtExprVisitor {
}

// the narrowed datatype of Var and IntImm
std::unordered_map<const Object*, DataType> vmap;
std::unordered_map<const PrimExprNode*, DataType> vmap;

protected:
// internal analyzer
arith::Analyzer analyzer_;

private:
// the maximum bits of all containing expressions
// the maximum possible bit of the current expression's return dtype
int bits_;
// the vars to be rewritten
std::unordered_set<const Object*> vset_;
// the extent of vars to be rewritten
std::unordered_map<const Object*, DataType> vextent_;
std::unordered_map<const VarNode*, DataType> vextent_;
};

class DataTypeRewriter : public StmtExprMutator {
Expand Down

0 comments on commit 0724b66

Please sign in to comment.