From 4e84153395dc53b3e76275aa9fe149bde1c6359e Mon Sep 17 00:00:00 2001
From: Haozheng Fan <fanhaozh@amazon.com>
Date: Fri, 20 Mar 2020 23:30:13 +0800
Subject: [PATCH] Resolve comments

---
 include/tvm/tir/ir_pass.h                     |  2 +-
 python/tvm/driver/build_module.py             |  2 +-
 src/tir/pass/ffi_api.cc                       |  2 +-
 ...rewrite_datatype.cc => narrow_datatype.cc} | 41 ++++++++++++++++++-
 ...pe.py => test_tir_pass_narrow_datatype.py} |  2 +-
 5 files changed, 43 insertions(+), 6 deletions(-)
 rename src/tir/pass/{rewrite_datatype.cc => narrow_datatype.cc} (88%)
 rename tests/python/unittest/{test_tir_pass_rewrite_datatype.py => test_tir_pass_narrow_datatype.py} (98%)

diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h
index e61c91a6a30bf..4c54ae49fee89 100644
--- a/include/tvm/tir/ir_pass.h
+++ b/include/tvm/tir/ir_pass.h
@@ -390,7 +390,7 @@ Stmt HoistIfThenElse(Stmt stmt);
  * \param stmt The stmt to do datatype rewrite
  * \return Transformed stmt.
  */
-Stmt DataTypeRewrite(Stmt stmt);
+Stmt NarrowDataType(Stmt stmt);
 
 /*!
  * \brief Make an user callable API LoweredFunc.
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 220de8eb12826..7ef5565fbd4a4 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -159,7 +159,7 @@ def lower(sch,
     # Phase 1
     stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
     stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
-    stmt = ir_pass.DataTypeRewrite(stmt)
+    stmt = ir_pass.NarrowDataType(stmt)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
         stmt = f(stmt)
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index 5e921ae0155ba..39524c4a2eb57 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -177,6 +177,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
 REGISTER_PASS(InferFragment)
-REGISTER_PASS(DataTypeRewrite);
+REGISTER_PASS(NarrowDataType);
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/narrow_datatype.cc
similarity index 88%
rename from src/tir/pass/rewrite_datatype.cc
rename to src/tir/pass/narrow_datatype.cc
index b1246edbfcd3c..bd7a8d40c1280 100644
--- a/src/tir/pass/rewrite_datatype.cc
+++ b/src/tir/pass/narrow_datatype.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file rewrite_datatype.cc
+ * \file narrow_datatype.cc
  * \brief narrow the datatype of indexing vars
  */
 
@@ -30,6 +30,28 @@
 namespace tvm {
 namespace tir {
 
+// This pass narrows indexing expressions (like StoreNode::Index)
+// that trivially fit into i32 to i32. Considering that i32 indices
+// may be more efficient on some backends (while i64 may be more
+// efficient on others, like llvm), we may want this pass when i32
+// indices are more efficient.
+//
+// For Var v, we determine its dtype by examining all the PrimExpr
+// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
+// If all expressions in E fit into i32, then we think v can be narrowed
+// to i32.
+//
+// To make an indexing expression i32, we must make sure that every
+// component of that expression is of dtype i32. So besides Var, we
+// rewrite the following inside an indexing expression
+// - Var
+// - IntImm
+// - Cast
+//
+// Algorithm:
+// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
+// - Use DataTypeRewritter to rewrite the components of an indexing expression.
+
 using arith::Analyzer;
 using arith::IRMutatorWithAnalyzer;
 using arith::ConstIntBound;
@@ -166,6 +188,9 @@ class DataTypeRewriter : public StmtExprMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     Stmt s = StmtExprMutator::VisitStmt_(op);
     op = s.as<ForNode>();
+    CHECK(op != nullptr)
+      << "Expected type to be ForNode"
+      << ", but get " << s->GetTypeKey();
     PrimExpr e = VisitExpr(op->loop_var);
     Var var = Downcast<Var, PrimExpr>(e);
     return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
@@ -177,7 +202,13 @@ class DataTypeRewriter : public StmtExprMutator {
         op->attr_key == attr::virtual_thread) {
       Stmt s = StmtExprMutator::VisitStmt_(op);
       op = s.as<AttrStmtNode>();
+      CHECK(op != nullptr)
+        << "Expected type to be AttrStmtNode"
+        << ", but get " << s->GetTypeKey();
       const IterVarNode* iv = op->node.as<IterVarNode>();
+      CHECK(iv != nullptr)
+        << "Expected type to be IterVarNode"
+        << ", but get " << op->node->GetTypeKey();
       PrimExpr e = VisitExpr(iv->var);
       Var var = Downcast<Var, PrimExpr>(e);
       if (ivmap_.find(iv) == ivmap_.end()) {
@@ -233,6 +264,9 @@ class DataTypeRewriter : public StmtExprMutator {
     if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
       PrimExpr e = StmtExprMutator::VisitExpr_(op);
       const CastNode* new_op = e.as<CastNode>();
+      CHECK(new_op != nullptr)
+        << "Expected type to be CastNode"
+        << ", but get " << e->GetTypeKey();
       return CastNode::make(visitor_.vmap[op], new_op->value);
     }
     return StmtExprMutator::VisitExpr_(op);
@@ -298,6 +332,9 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)
 PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
   PrimExpr e = StmtExprMutator::VisitExpr_(op);
   op = e.as<CallNode>();
+  CHECK(op != nullptr)
+    << "Expected type to be CallNode"
+    << ", but get " << e->GetTypeKey();
   if (op->call_type == CallNode::PureIntrinsic) {
     if (op->name == intrinsic::tvm_if_then_else) {
       return if_then_else(op->args[0], op->args[1], op->args[2]);
@@ -318,7 +355,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
   return e;
 }
 
-Stmt DataTypeRewrite(Stmt stmt) {
+Stmt NarrowDataType(Stmt stmt) {
   return DataTypeRewriter()(stmt);
 }
 
diff --git a/tests/python/unittest/test_tir_pass_rewrite_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py
similarity index 98%
rename from tests/python/unittest/test_tir_pass_rewrite_datatype.py
rename to tests/python/unittest/test_tir_pass_narrow_datatype.py
index 69eee8cc19d6e..297dc166aa6a0 100644
--- a/tests/python/unittest/test_tir_pass_rewrite_datatype.py
+++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py
@@ -32,7 +32,7 @@ def lower(sch, args):
     bounds = te.schedule.InferBound(sch)
     stmt = te.schedule.ScheduleOps(sch, bounds)
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
-    stmt = tvm.tir.ir_pass.DataTypeRewrite(stmt)
+    stmt = tvm.tir.ir_pass.NarrowDataType(stmt)
     return stmt