Skip to content

Commit

Permalink
Fix sanity
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Apr 1, 2020
1 parent 9c5acee commit 6cee7b7
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ def __init__(self, dom, var, iter_type, thread_tag=""):
raise TypeError("dom need to be Range")

name = var if var is not None else "iter"
var = Var(name, dtype="int32") if not isinstance(var, Var) else var
dtype = "int32" if dom is None else dom.extent.dtype
var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)

Expand Down
4 changes: 2 additions & 2 deletions src/tir/pass/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace tvm {
namespace tir {

// This pass narrows indexing expressions (like StoreNode::Index)
// that trivially fit into i32/i16 (denoted by `target_bits_`) to
// that trivially fit into i32/i16 (denoted by `target_bits_`) to
// i32/i16. Considering that i32/i16 indices may be more
// efficient on some backends (while i64 may be more efficient
// on others, like llvm), we may want this pass when i32/i16
Expand Down Expand Up @@ -62,7 +62,7 @@ using arith::ConstIntBound;
//
// Algorithm:
// We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
// To be more specific, if for each Expr `e` which contains `var`
// To be more specific, if for each Expr `e` which contains `var`
// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`,
// then we narrow `var` into `target_bits_`. That is,
// `vmap[var] = min(target_bits_, var.dtype.bits())`
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_tir_pass_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def check(m, target_bits, target_dtype):

def test_slice():
def check(m, n, target_bits, target_dtype):
# The index may overflow in B, while not in A
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Expand Down

0 comments on commit 6cee7b7

Please sign in to comment.