diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index deb8d3446fc1..a192fce6439a 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -370,7 +370,8 @@ def __init__(self, dom, var, iter_type, thread_tag=""): raise TypeError("dom need to be Range") name = var if var is not None else "iter" - var = Var(name, dtype="int32") if not isinstance(var, Var) else var + dtype = "int32" if dom is None else dom.extent.dtype + var = Var(name, dtype=dtype) if not isinstance(var, Var) else var self.__init_handle_by_constructor__( _ffi_api.IterVar, dom, var, iter_type, thread_tag) diff --git a/src/tir/pass/narrow_datatype.cc b/src/tir/pass/narrow_datatype.cc index d376e41d1e64..e6bc43abf7cb 100644 --- a/src/tir/pass/narrow_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -31,7 +31,7 @@ namespace tvm { namespace tir { // This pass narrows indexing expressions (like StoreNode::Index) -// that trivially fit into i32/i16 (denoted by `target_bits_`) to +// that trivially fit into i32/i16 (denoted by `target_bits_`) to // i32/i16. Considering that i32/i16 indices may be more // efficient on some backends (while i64 may be more efficient // on others, like llvm), we may want this pass when i32/i16 @@ -62,7 +62,7 @@ using arith::ConstIntBound; // // Algorithm: // We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`. -// To be more specific, if for each Expr `e` which contains `var` +// To be more specific, if for each Expr `e` which contains `var` // (`var` is a child node of `e` in AST), `e` fits into `target_bits_`, // then we narrow `var` into `target_bits_`. That is, // `vmap[var] = min(target_bits_, var.dtype.bits())` diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py index 85b3f63261dd..96a8ef41f793 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -157,6 +157,7 @@ def check(m, target_bits, target_dtype): def test_slice(): def check(m, n, target_bits, target_dtype): + # The index may overflow in B, while not in A ib = tvm.tir.ir_builder.create() Ab = tvm.tir.decl_buffer((m, n), name='A') A = ib.buffer_ptr(Ab)