Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Apr 1, 2020
1 parent 0578b85 commit 9c5acee
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 33 deletions.
40 changes: 29 additions & 11 deletions src/tir/pass/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,19 @@ namespace tvm {
namespace tir {

// This pass narrows indexing expressions (like StoreNode::Index)
// that trivially fit into i32 to i32. Considering that i32 indices
// may be more efficient on some backends (while i64 may be more
// efficient on others, like llvm), we may want this pass when i32
// that trivially fit into i32/i16 (denoted by `target_bits_`) to
// i32/i16. Considering that i32/i16 indices may be more
// efficient on some backends (while i64 may be more efficient
// on others, like llvm), we may want this pass when i32/i16
// indices are more efficient.
//
// For Var v, we determine its dtype by examining all the PrimExpr
// that contain v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If all expressions in E fit into i32, then we think v can be narrowed
// to i32.
// If all expressions in E fit into i32/i16, then we think v can be narrowed
// to i32/i16.
//
// To make an indexing expression i32, we must make sure that every
// component of that expression is of dtype i32. So besides Var, we
// To make an indexing expression i32/i16, we must make sure that every
// component of that expression is of dtype i32/i16. So besides Var, we
// rewrite the following inside an indexing expression
// - Var
// - IntImm
Expand All @@ -56,6 +57,16 @@ using arith::Analyzer;
using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;

// Determine the result dtype for Var, IntImm and Cast,
// which will be stored in `vmap` eventually.
//
// Algorithm:
// We propagate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
// To be more specific, if for each Expr `e` which contains `var`
// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`,
// then we narrow `var` into `target_bits_`. That is,
// `vmap[var] = min(target_bits_, var.dtype.bits())`
// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
class DataTypeVisitor final : public StmtExprVisitor {
public:
explicit DataTypeVisitor(int target_bits)
Expand All @@ -65,8 +76,8 @@ class DataTypeVisitor final : public StmtExprVisitor {
if (e.dtype().is_int()) {
int bits = max_bits_;
ConstIntBound bound = analyzer_.const_int_bound(e);
int64_t ubound = Downcast<IntImm, PrimExpr>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm, PrimExpr>(min_value(DataType::Int(target_bits_)))->value;
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
(bound->max_value <= ubound && bound->min_value >= lbound)) {
bits = target_bits_;
Expand Down Expand Up @@ -113,10 +124,13 @@ class DataTypeVisitor final : public StmtExprVisitor {

void VisitExpr_(const VarNode* op) {
if (vextent_.find(op) != vextent_.end()) {
// We only narrow and never promote, so the result dtype
// is upper-bounded by its original dtype before rewrite.
int bits = std::min(vextent_[op].bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
// We take maximum bits for all the possible Expr where a var occurs
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
Expand All @@ -125,6 +139,8 @@ class DataTypeVisitor final : public StmtExprVisitor {

void VisitExpr_(const IntImmNode* op) {
if (op->dtype.is_int()) {
// We only narrow and never promote, so the result dtype
// is upper-bounded by its original dtype before rewrite.
int bits = std::min(op->dtype.bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
Expand All @@ -137,6 +153,8 @@ class DataTypeVisitor final : public StmtExprVisitor {

void VisitExpr_(const CastNode* op) {
if (op->dtype.is_int()) {
// We only narrow and never promote, so the result dtype
// is upper-bounded by its original dtype before rewrite.
int bits = std::min(op->dtype.bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
Expand Down Expand Up @@ -201,7 +219,7 @@ class DataTypeRewriter : public StmtExprMutator {
<< "Expected type to be ForNode"
<< ", but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var, PrimExpr>(e);
Var var = Downcast<Var>(e);
return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
op->for_type, op->device_api, op->body);
}
Expand All @@ -219,7 +237,7 @@ class DataTypeRewriter : public StmtExprMutator {
<< "Expected type to be IterVarNode"
<< ", but get " << op->node->GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
Var var = Downcast<Var, PrimExpr>(e);
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag);
}
Expand Down
93 changes: 71 additions & 22 deletions tests/python/unittest/test_tir_pass_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
import tvm
from tvm import te
from tvm.tir import const


def lower(sch, args):
def lower(sch, args, target_bits):
binds = {}
arg_list = []
for x in args:
Expand All @@ -32,7 +33,7 @@ def lower(sch, args):
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
return stmt


Expand All @@ -52,10 +53,16 @@ def check(m, n, target_bits, target_dtype):
assert stmt.body.loop_var.dtype == target_dtype

# const shape
# i32 -> i32
check(2, 2, 32, "int32")
check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow
check(tvm.tir.const(2, dtype='int64'), tvm.tir.const(2, dtype='int64'), 32, "int32")
check(tvm.tir.const(2**16, dtype='int64'), tvm.tir.const(2**16, dtype='int64'), 32, "int64")
# i64 -> i32
check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
# i32 -> i16
check(2, 2, 16, "int16")
check(2**10, 2**10, 16, "int32")

# symbolic shape
check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32")
check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64")
Expand All @@ -78,18 +85,23 @@ def check(m, n, target_bits, target_dtype):
assert stmt.node.var.dtype == target_dtype
assert stmt.body.node.var.dtype == target_dtype


# i32 -> i32
check(2, 32,
target_bits=32, target_dtype='int32')
# i32 + i32 is not promoted to i64 even in the case of overflow
check(2**30, 32,
check(2**30, 32, # i32 + i32 is not promoted to i64 even in the case of overflow
target_bits=32, target_dtype='int32')
check(tvm.tir.const(2, dtype='int64'),
tvm.tir.const(32, dtype='int64'),
# i64 -> i32
check(const(2, dtype='int64'),
const(32, dtype='int64'),
target_bits=32, target_dtype='int32')
check(tvm.tir.const(2**30, dtype='int64'),
tvm.tir.const(32, dtype='int64'),
check(const(2**30, dtype='int64'),
const(32, dtype='int64'),
target_bits=32, target_dtype='int64')
# i32 -> i16
check(2, 32,
target_bits=16, target_dtype='int16')
check(2**14, 32,
target_bits=16, target_dtype='int32')


def test_multilanes():
Expand All @@ -105,33 +117,70 @@ def check(m, lanes, target_bits, target_dtype):
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype

check(tvm.tir.const(2 ** 10, dtype='int32'), 2,
# i32 -> i32
check(const(2 ** 10, dtype='int32'), 2,
target_bits=32, target_dtype='int32')
check(tvm.tir.const(2 ** 32, dtype='int32'), 2,
check(const(2 ** 32, dtype='int32'), 2,
target_bits=32, target_dtype='int32')
check(tvm.tir.const(2 ** 10, dtype='int64'), 2,
# i64 -> i32
check(const(2 ** 10, dtype='int64'), 2,
target_bits=32, target_dtype='int32')
check(tvm.tir.const(2 ** 32, dtype='int64'), 2,
check(const(2 ** 32, dtype='int64'), 2,
target_bits=32, target_dtype='int64')
# i32 -> i16
check(const(2 ** 10, dtype='int32'), 2,
target_bits=16, target_dtype='int16')
check(const(2 ** 16, dtype='int32'), 2,
target_bits=16, target_dtype='int32')


def test_reduce():
def check(m, dtype):
def check(m, target_bits, target_dtype):
A = te.placeholder((m,), name='A', dtype='float32')
k = te.reduce_axis((0, m), "k")
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
s = te.create_schedule(B.op)
stmt = lower(s, [A, B])
assert stmt.body[1].loop_var.dtype == dtype
stmt = lower(s, [A, B], target_bits)
assert stmt.body[1].loop_var.dtype == target_dtype

# i32 -> i32
check(const(64, dtype='int32'), 32, 'int32')
# i64 -> i32
check(const(64, dtype='int64'), 32, 'int32')
# i32 -> i16
check(const(64, dtype='int32'), 16, 'int16')
check(const(2**16, dtype='int32'), 16, 'int32')
# symbolic
check(te.var('n', dtype='int32'), 32, 'int32')
check(te.var('n', dtype='int64'), 32, 'int64')


def test_slice():
def check(m, n, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i') as i:
with ib.for_range(0, n, name='j') as j:
A[i * n + j] = B[i * 2 * n + 2 * j] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype

check(tvm.tir.const(64, dtype='int32'), 'int32')
check(tvm.tir.const(64, dtype='int64'), 'int32')
check(te.var('n', dtype='int32'), 'int32')
check(te.var('n', dtype='int64'), 'int64')
# The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
check(const(2**15, 'int64'), const(2**15, 'int64'),
target_bits=32, target_dtype='int32')
# The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
check(const(2**15, 'int64'), const((2**15 + 1), 'int64'),
target_bits=32, target_dtype='int64')


if __name__ == "__main__":
test_basic()
test_thread_axis()
test_multilanes()
test_reduce()
test_slice()

0 comments on commit 9c5acee

Please sign in to comment.