[Fix] Fix IndexDataTypeNormalizer to avoid redundant casting (#13449)
This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.

## Background

Consider the following case,

```python
from tvm import te, tir, topi

A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override=None)
```

the generated PrimFunc is

```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(4, 2), "float32"]):
    for i0, i1 in T.grid(4, 2):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```

Here the loop variables `ax0` and `ax1` have dtype int32, since the shape of the output buffer is in int32. On the other hand, the input buffer has its shape in int64. So, as the script above shows, CreatePrimFunc first casts the int32 variables to int64 and then accesses the input buffer.

Now if we use the option `index_dtype_override` to specify an index dtype as below,

```python
func = te.create_prim_func([A, B], index_dtype_override="int64")
```

the generated function will be

```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```

Note that though all variables and the buffer shapes have dtype int64, there are still CastNodes such as `T.Cast("int64", ax0)` when `ax0` is already an int64 variable. We don’t want such redundant casting.

## Fix

To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* cast)` method in IndexDataTypeNormalizer. When the `value` field of a CastNode already has the target dtype, we no longer cast it.
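
The rule described in the fix can be illustrated with a small standalone Python model. This is only a sketch of the idea, not the TVM implementation (the actual change is in the C++ `IndexDataTypeNormalizer`). The classes `Var`, `IntImm`, `Cast`, `Add`, `Mul` and the functions `dtype_of` and `normalize` below are hypothetical stand-ins for TIR nodes and the visitor, and the sketch assumes every variable it sees is an index variable.

```python
from dataclasses import dataclass
from typing import Union


# Toy stand-ins for TIR expression nodes (hypothetical, not the TVM classes).
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class IntImm:
    dtype: str
    value: int


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: "Expr"


@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"


Expr = Union[Var, IntImm, Cast, Add, Mul]


def dtype_of(expr: Expr) -> str:
    """Return the dtype of an expression; binary ops take the dtype of their left operand."""
    if isinstance(expr, (Add, Mul)):
        return dtype_of(expr.a)
    return expr.dtype


def normalize(expr: Expr, target: str = "int64") -> Expr:
    """Rewrite index variables and constants to `target`, dropping casts that become no-ops."""
    if isinstance(expr, Var):
        return Var(expr.name, target)
    if isinstance(expr, IntImm):
        return IntImm(target, expr.value)
    if isinstance(expr, (Add, Mul)):
        return type(expr)(normalize(expr.a, target), normalize(expr.b, target))
    if isinstance(expr, Cast):
        value = normalize(expr.value, target)
        # The key step: if the rewritten value already has the cast's dtype,
        # return it directly instead of wrapping it in another Cast.
        if dtype_of(value) == expr.dtype:
            return value
        return Cast(expr.dtype, value)
    raise TypeError(f"unsupported node: {expr!r}")


# Models Cast("int64", ax0) * int64(2) + Cast("int64", ax1) with int32 loop variables.
index = Add(
    Mul(Cast("int64", Var("ax0", "int32")), IntImm("int64", 2)),
    Cast("int64", Var("ax1", "int32")),
)
result = normalize(index, "int64")
```

In this model, `result` contains no `Cast` nodes, because both loop variables are rewritten to int64 before the cast is reconsidered. A cast to a different dtype, such as `Cast("float32", ...)`, is kept.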