diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index c341fc34b1681..39b39d28202ec 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -389,9 +389,10 @@ Stmt HoistIfThenElse(Stmt stmt); /*! * \brief Narrow down PrimExpr datatype in stmt * \param stmt The stmt to do datatype rewrite + * \param target_bits the bit of target datatype * \return Transformed stmt. */ -Stmt NarrowDataType(Stmt stmt); +Stmt NarrowDataType(Stmt stmt, int target_bits); /*! * \brief Make an user callable API LoweredFunc. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 7ef5565fbd4a4..88231aaba56cf 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,7 +159,7 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.NarrowDataType(stmt) + stmt = ir_pass.NarrowDataType(stmt, 32) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: stmt = f(stmt) diff --git a/src/tir/pass/narrow_datatype.cc b/src/tir/pass/narrow_datatype.cc index bd7a8d40c1280..863edea0199d8 100644 --- a/src/tir/pass/narrow_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -56,22 +56,25 @@ using arith::Analyzer; using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; -class DataTypeRewriter; - class DataTypeVisitor final : public StmtExprVisitor { public: + explicit DataTypeVisitor(int target_bits) + : bits_(target_bits), target_bits_(target_bits) {} + void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { - int bits = 64; - if (e.dtype().bits() <= 32 || - analyzer_.CanProve(e <= max_value(DataType::Int(32)) && - e >= min_value(DataType::Int(32)))) { - bits = 32; + int bits = max_bits_; + ConstIntBound bound = analyzer_.const_int_bound(e); + int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; + int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; + if (e.dtype().bits() <= target_bits_ || + (bound->max_value <= ubound && bound->min_value >= lbound)) { + bits = target_bits_; } - int tmp = bits_; - bits_ = bits > bits_ ? bits : bits_; + int tmp = bits > bits_ ? bits : bits_; + std::swap(bits_, tmp); StmtExprVisitor::VisitExpr(e); - bits_ = tmp; + std::swap(bits_, tmp); } else { StmtExprVisitor::VisitExpr(e); } @@ -152,14 +155,20 @@ class DataTypeVisitor final : public StmtExprVisitor { arith::Analyzer analyzer_; private: + // the maximum possible bits, which serves as an init value + static constexpr const int max_bits_ = 64; // the maximum possible bit of the current expression's return dtype int bits_; + // the target bits + int target_bits_; // the extent of vars to be rewritten std::unordered_map vextent_; }; class DataTypeRewriter : public StmtExprMutator { public: + explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {} + Stmt operator()(Stmt s) { visitor_(s); for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { @@ -298,6 +307,7 @@ class DataTypeRewriter : public StmtExprMutator { // a map from IterVar before rewrite to that after rewrite, // ensures one old IterVar maps to exactly one new IterVar std::unordered_map ivmap_; + // indicator of LoadNode::index and StoreNode::index bool is_index_{false}; }; @@ -355,8 +365,8 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt NarrowDataType(Stmt stmt) { - return DataTypeRewriter()(stmt); +Stmt NarrowDataType(Stmt stmt, int target_bits) { + return DataTypeRewriter(target_bits)(stmt); } } // namespace tir diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py index 297dc166aa6a0..539decc540531 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -32,7 +32,7 @@ def lower(sch, args): bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - stmt = tvm.tir.ir_pass.NarrowDataType(stmt) + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32) return stmt