This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Cherry-Pick][Fix] Fix IndexDataTypeNormalizer (apache/tvm#13449) #355


MasterJH5574

This PR is a cherry-pick of apache/tvm#13449, which fixed a bug in IndexDataTypeNormalizer so that the PrimFunc generated by CreatePrimFunc no longer contains useless i64-to-i64 var casts.

This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.

Consider the following case,
```python
A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override=None)
```
the generated PrimFunc is
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(4, 2), "float32"]):
    for i0, i1 in T.grid(4, 2):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Here the loop variables `ax0` and `ax1` have dtype int32, since the shape of the output buffer is in int32. On the other hand, the input buffer has its shape in int64. So, as the script above shows, CreatePrimFunc first casts the int32 variables to int64 and then accesses the input buffer.
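
As a sanity check on the index arithmetic in the script above, here is a small pure-Python sketch (no TVM involved) that evaluates the same expression for every `(ax0, ax1)` pair. The helper name `source_index` is made up for illustration; it only mirrors the expression printed in the `T.reads` line.

```python
# Pure-Python check of the index arithmetic from the T_reshape block.
# `source_index` is a hypothetical helper, not part of TVM.

def source_index(ax0, ax1):
    """Map an index of the (4, 2) output to an index of the (2, 4) input."""
    flat = ax0 * 2 + ax1          # row-major offset in the (4, 2) output
    return (flat % 8 // 4, flat % 4)

# Evaluate the mapping over the whole T.grid(4, 2) iteration space.
mapping = {
    (ax0, ax1): source_index(ax0, ax1)
    for ax0 in range(4)
    for ax1 in range(2)
}

for out_idx, in_idx in mapping.items():
    print(out_idx, "->", in_idx)
```

Each output element reads the input element with the same row-major offset, which is what a reshape is expected to do.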

Now if we use the option `index_dtype_override` to specify an index dtype as below,
```python
func = te.create_prim_func([A, B], index_dtype_override="int64")
```
the generated function will be
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Note that although all variables and buffer shapes now have dtype int64, the output still contains CastNodes such as `T.Cast("int64", ax0)`, even though `ax0` is already an int64 variable. We don’t want such redundant casting.

To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* cast)` method in IndexDataTypeNormalizer. When the `value` field of a CastNode already has the target dtype, we no longer cast it.
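
The rule described above can be illustrated with a toy model in plain Python. This is only a sketch of the idea, not the actual C++ code in IndexDataTypeNormalizer: the `Var` and `Cast` classes and the `drop_redundant_casts` function are hypothetical stand-ins for the TIR nodes and the visitor method.

```python
from dataclasses import dataclass

# Toy stand-ins for TIR nodes; these are NOT the real TVM classes.
@dataclass(frozen=True)
class Var:
    name: str
    dtype: str

@dataclass(frozen=True)
class Cast:
    dtype: str      # target dtype of the cast
    value: object   # expression being cast (Var or Cast in this toy model)

def drop_redundant_casts(expr):
    """Remove casts whose operand already has the cast's target dtype."""
    if isinstance(expr, Var):
        return expr
    if isinstance(expr, Cast):
        value = drop_redundant_casts(expr.value)
        if value.dtype == expr.dtype:
            return value              # operand already has the target dtype
        return Cast(expr.dtype, value)
    raise TypeError(f"unsupported node: {expr!r}")

# An int32 variable still needs its cast to int64.
needed = drop_redundant_casts(Cast("int64", Var("ax0", "int32")))
# An int64 variable does not, so the cast disappears.
redundant = drop_redundant_casts(Cast("int64", Var("ax0", "int64")))
print(needed)
print(redundant)
```

In this model the first case keeps its cast, matching the `index_dtype_override=None` output, while the second collapses to the bare variable, which is the behavior this PR aims for with `index_dtype_override="int64"`.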
@MasterJH5574 MasterJH5574 force-pushed the relax-dev/2023-01-11-index-dtype-normalizer-cp branch from 15d5b05 to f4e1ca9 Compare January 11, 2023 19:34
@tqchen tqchen merged commit ef33165 into tlc-pack:relax Jan 11, 2023