Skip to content

Commit

Permalink
[Fix] Fix IndexDataTypeNormalizer to avoid redundant casting (apache/…
Browse files Browse the repository at this point in the history
…tvm#13449)

This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.

Consider the following case,
```python
A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override=None)
```
the generated PrimFunc is
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(4, 2), "float32"]):
    for i0, i1 in T.grid(4, 2):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Here loop variables `ax0` and `ax1` have dtype int32, since the shape of the output buffer is in int32. Other other hand, the input buffer has shape in int64. So as the script above shows, CreatePrimFunc will cast the int32 variables to int64 first, and access the input buffer afterwards.

Now if we use the option `index_dtype_override` to specify an index dtype as below,
```python
func = te.create_prim_func([A, B], index_dtype_override="int64")
```
the generated function will be
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Note that though all variables and the buffer shapes have dtype int64, there are still CastNodes such as `T.Cast("int64", ax0)` when `ax0` is already an int64 variable. We don’t want such redundant casting.

To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* cast)` method in IndexDataTypeNormalizer. When the `value` field of a CastNode already has the target dtype, we no longer cast it.
  • Loading branch information
MasterJH5574 committed Jan 11, 2023
1 parent 64ec6a5 commit e6df7e6
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 73 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
using Parent::VisitStmt_;
PrimExpr VisitExpr_(const IntImmNode* op) final;
PrimExpr VisitExpr_(const VarNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

DataType target_data_type_ = DataType::Int(64);
};
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants,
const Optional<Array<tir::Var>>& tir_var_list,
std::optional<DataType> index_dtype_override) {
// Infomations used in CreatePrimFunc and its sub-functions.
// Information used in CreatePrimFunc and its sub-functions.
CreateFuncInfo info(arg_list);
// Root body stmts.
Array<Stmt> root_stmts;
Expand Down
9 changes: 9 additions & 0 deletions src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {

IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
: target_data_type_(std::move(target_data_type)) {}

PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
Map<Var, Buffer> new_buffer_map = func->buffer_map;
for (const auto& [var, buffer] : func->buffer_map) {
Expand Down Expand Up @@ -542,5 +543,13 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
return GetRef<PrimExpr>(op);
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
if (is_enabled_) {
PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
}
return IndexDataTypeRewriter::VisitExpr_(op);
}

} // namespace tir
} // namespace tvm
106 changes: 34 additions & 72 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def test_unique_name_reduction_block():

def _check_workload(te_workload, tir_workload, index_dtype_override=None):
func = te.create_prim_func(te_workload(), index_dtype_override)
print(func.script())
print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload))
tvm.ir.assert_structural_equal(func, tir_workload)
# make sure that we can create schedule from the func
s = tir.Schedule(func, debug_mask="all")
Expand Down Expand Up @@ -598,6 +596,39 @@ def expected(
_check_workload(te_func, expected)


def te_reshape():
# The following is possible to be generated by TOPI. So we test this case.
A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
return [A, B]


@T.prim_func
def tir_reshape(
A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"],
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0, i1 in T.grid(T.int64(4), T.int64(2)):
with T.block("T_reshape"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(
A[
(ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
(ax0 * T.int64(2) + ax1) % T.int64(4),
]
)
T.writes(T_reshape[ax0, ax1])
T_reshape[ax0, ax1] = A[
(ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
(ax0 * T.int64(2) + ax1) % T.int64(4),
]


def test_reshape():
_check_workload(te_reshape, tir_reshape, index_dtype_override="int64")


def test_unbound_var():
n = tir.Var("n", "int32")
A = te.placeholder((n + 1,), name="A")
Expand All @@ -614,75 +645,6 @@ def test_unbound_var():
tvm.testing.assert_allclose(a_np, b.numpy())


def te_argmax():
# x and y are the operands of reduction, both of them is a tuple of index
# and value.
def fcombine(x, y):
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs

# our identity element also need to be a tuple, so `fidentity` accepts
# two types as inputs.
def fidentity(t0, t1):
return tvm.tir.const(-1, t0), tvm.te.min_value(t1)

argmax = te.comm_reducer(fcombine, fidentity, name="argmax")

# describe the reduction computation
m = te.var("m")
n = te.var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="int32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
return [idx, val, T0, T1]


@T.prim_func
def tir_argmax(
var_idx: T.handle, var_val: T.handle, var_T_v0: T.handle, var_T_v1: T.handle
) -> None:
m = T.var("int32")
n = T.var("int32")
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
val = T.match_buffer(var_val, [m, n], dtype="int32")
T_v0 = T.match_buffer(var_T_v0, [m], dtype="int32")
T_v1 = T.match_buffer(var_T_v1, [m], dtype="int32")
# body
# with T.block("root")
for i0, i1 in T.grid(m, n):
with T.block("T.v0"):
i, k = T.axis.remap("SR", [i0, i1])
with T.init():
T_v0[i] = -1
T_v1[i] = -2147483648
T_v0[i] = T.Select(T_v1[i] >= val[i, k], T_v0[i], idx[i, k])
T_v1[i] = T.Select(T_v1[i] >= val[i, k], T_v1[i], val[i, k])


def test_argmax():
_check_workload(te_argmax, tir_argmax)

dtype = "int32"
func = te.create_prim_func(te_argmax())
assert len(func.params) == 4

func = tvm.build(func)

idx_np = np.arange(100, dtype=dtype).reshape((10, 10))
val_np = np.random.permutation(100).reshape((10, 10)).astype(dtype)
c = tvm.nd.array(np.zeros(10, dtype=dtype)) # argmax index
d = tvm.nd.array(np.zeros(10, dtype=dtype)) # max value
func(tvm.nd.array(idx_np), tvm.nd.array(val_np), c, d)

c_expected = idx_np[np.arange(10), np.argmax(val_np, axis=1)]
d_expected = np.amax(val_np, axis=1)

tvm.testing.assert_allclose(c_expected, c.numpy())
tvm.testing.assert_allclose(d_expected, d.numpy())


if __name__ == "__main__":
test_unique_name_complete_block()
test_unique_name_reduction_block()
Expand All @@ -701,6 +663,6 @@ def test_argmax():
test_argmax_val_idx()
test_int64_indices()
test_zero_dim_add()
test_reshape()
test_loop_var_datatype()
test_unbound_var()
test_argmax()

0 comments on commit e6df7e6

Please sign in to comment.