From f4e1ca93132eaec6c528e9139bf889747b965f2e Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 20 Nov 2022 22:40:55 -0500
Subject: [PATCH] [Cherry-Pick][Fix] Fix IndexDataTypeNormalizer
 (apache/tvm#13449)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.

Consider the following case:

```python
A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override=None)
```

The generated PrimFunc is

```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(4, 2), "float32"]):
    for i0, i1 in T.grid(4, 2):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```

Here the loop variables `ax0` and `ax1` have dtype int32, since the shape of
the output buffer is in int32, while the input buffer's shape is in int64. So,
as the script above shows, CreatePrimFunc casts the int32 variables to int64
first and accesses the input buffer afterwards.

Now if we use the option `index_dtype_override` to specify an index dtype as
below,

```python
func = te.create_prim_func([A, B], index_dtype_override="int64")
```

the generated function will be

```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```

Note that although all variables and buffer shapes now have dtype int64, there
are still CastNodes such as `T.Cast("int64", ax0)`, even though `ax0` is
already an int64 variable. We don't want such redundant casting.

To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* cast)`
method in IndexDataTypeNormalizer. When the `value` field of a CastNode already
has the target dtype, we no longer cast it.
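For reference, a minimal sanity check of the new behavior (not part of this
patch, and assuming a TVM build that already contains this fix): rebuild the
reshape example above with the int64 override and verify that no CastNode is
left behind, matching the `tir_reshape` expectation added to the tests below.

```python
from tvm import te, tir, topi


def count_casts(func: tir.PrimFunc) -> int:
    # Collect every tir.Cast node reachable from the PrimFunc body.
    casts = []

    def visit(node):
        if isinstance(node, tir.Cast):
            casts.append(node)

    tir.stmt_functor.post_order_visit(func.body, visit)
    return len(casts)


A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override="int64")
# Before this fix the indices contain redundant T.Cast("int64", ...) nodes;
# with the fix, the normalizer should leave no casts behind.
assert count_casts(func) == 0
```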
---
 include/tvm/tir/data_type_rewriter.h          |   1 +
 src/te/operation/create_primfunc.cc           |   2 +-
 src/tir/ir/data_type_rewriter.cc              |   9 ++
 .../unittest/test_te_create_primfunc.py       | 106 ++++++------------
 4 files changed, 45 insertions(+), 73 deletions(-)

diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 378addaba5..bf90aaedfe 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
   using Parent::VisitStmt_;
   PrimExpr VisitExpr_(const IntImmNode* op) final;
   PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;
 
   DataType target_data_type_ = DataType::Int(64);
 };
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 936aa389d7..75f02be0aa 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -512,7 +512,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
                                      const Array<runtime::NDArray>& constants,
                                      const Optional<Array<tir::Var>>& tir_var_list,
                                      std::optional<DataType> index_dtype_override) {
-  // Infomations used in CreatePrimFunc and its sub-functions.
+  // Information used in CreatePrimFunc and its sub-functions.
   CreateFuncInfo info(arg_list);
   // Root body stmts.
   Array<Stmt> root_stmts;
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 1a017ce831..27a59d9709 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -511,6 +511,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {
 
 IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
     : target_data_type_(std::move(target_data_type)) {}
+
 PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
   Map<Var, Buffer> new_buffer_map = func->buffer_map;
   for (const auto& [var, buffer] : func->buffer_map) {
@@ -542,5 +543,13 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
   return GetRef<PrimExpr>(op);
 }
 
+PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
+  if (is_enabled_) {
+    PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
+    return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
+  }
+  return IndexDataTypeRewriter::VisitExpr_(op);
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 22adb72ea9..8059d02409 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -46,8 +46,6 @@ def test_unique_name_reduction_block():
 
 def _check_workload(te_workload, tir_workload, index_dtype_override=None):
     func = te.create_prim_func(te_workload(), index_dtype_override)
-    print(func.script())
-    print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload))
     tvm.ir.assert_structural_equal(func, tir_workload)
     # make sure that we can create schedule from the func
     s = tir.Schedule(func, debug_mask="all")
@@ -598,6 +596,39 @@ def expected(
     _check_workload(te_func, expected)
 
 
+def te_reshape():
+    # The following is possible to be generated by TOPI. So we test this case.
+    A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A")
+    B = topi.reshape(A, (4, 2))
+    return [A, B]
+
+
+@T.prim_func
+def tir_reshape(
+    A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
+    T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"],
+):
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
+        with T.block("T_reshape"):
+            ax0, ax1 = T.axis.remap("SS", [i0, i1])
+            T.reads(
+                A[
+                    (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
+                    (ax0 * T.int64(2) + ax1) % T.int64(4),
+                ]
+            )
+            T.writes(T_reshape[ax0, ax1])
+            T_reshape[ax0, ax1] = A[
+                (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
+                (ax0 * T.int64(2) + ax1) % T.int64(4),
+            ]
+
+
+def test_reshape():
+    _check_workload(te_reshape, tir_reshape, index_dtype_override="int64")
+
+
 def test_unbound_var():
     n = tir.Var("n", "int32")
     A = te.placeholder((n + 1,), name="A")
@@ -614,75 +645,6 @@ def test_unbound_var():
     tvm.testing.assert_allclose(a_np, b.numpy())
 
 
-def te_argmax():
-    # x and y are the operands of reduction, both of them is a tuple of index
-    # and value.
-    def fcombine(x, y):
-        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
-        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
-        return lhs, rhs
-
-    # our identity element also need to be a tuple, so `fidentity` accepts
-    # two types as inputs.
-    def fidentity(t0, t1):
-        return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
-
-    argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
-
-    # describe the reduction computation
-    m = te.var("m")
-    n = te.var("n")
-    idx = te.placeholder((m, n), name="idx", dtype="int32")
-    val = te.placeholder((m, n), name="val", dtype="int32")
-    k = te.reduce_axis((0, n), "k")
-    T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
-    return [idx, val, T0, T1]
-
-
-@T.prim_func
-def tir_argmax(
-    var_idx: T.handle, var_val: T.handle, var_T_v0: T.handle, var_T_v1: T.handle
-) -> None:
-    m = T.var("int32")
-    n = T.var("int32")
-    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
-    val = T.match_buffer(var_val, [m, n], dtype="int32")
-    T_v0 = T.match_buffer(var_T_v0, [m], dtype="int32")
-    T_v1 = T.match_buffer(var_T_v1, [m], dtype="int32")
-    # body
-    # with T.block("root")
-    for i0, i1 in T.grid(m, n):
-        with T.block("T.v0"):
-            i, k = T.axis.remap("SR", [i0, i1])
-            with T.init():
-                T_v0[i] = -1
-                T_v1[i] = -2147483648
-            T_v0[i] = T.Select(T_v1[i] >= val[i, k], T_v0[i], idx[i, k])
-            T_v1[i] = T.Select(T_v1[i] >= val[i, k], T_v1[i], val[i, k])
-
-
-def test_argmax():
-    _check_workload(te_argmax, tir_argmax)
-
-    dtype = "int32"
-    func = te.create_prim_func(te_argmax())
-    assert len(func.params) == 4
-
-    func = tvm.build(func)
-
-    idx_np = np.arange(100, dtype=dtype).reshape((10, 10))
-    val_np = np.random.permutation(100).reshape((10, 10)).astype(dtype)
-    c = tvm.nd.array(np.zeros(10, dtype=dtype))  # argmax index
-    d = tvm.nd.array(np.zeros(10, dtype=dtype))  # max value
-    func(tvm.nd.array(idx_np), tvm.nd.array(val_np), c, d)
-
-    c_expected = idx_np[np.arange(10), np.argmax(val_np, axis=1)]
-    d_expected = np.amax(val_np, axis=1)
-
-    tvm.testing.assert_allclose(c_expected, c.numpy())
-    tvm.testing.assert_allclose(d_expected, d.numpy())
-
-
 if __name__ == "__main__":
     test_unique_name_complete_block()
     test_unique_name_reduction_block()
@@ -701,6 +663,6 @@ def test_argmax():
     test_argmax_val_idx()
     test_int64_indices()
     test_zero_dim_add()
+    test_reshape()
     test_loop_var_datatype()
     test_unbound_var()
-    test_argmax()