Skip to content

Commit

Permalink
[Fix] Fix IndexDataTypeNormalizer to avoid redundant casting (#13449)
Browse files Browse the repository at this point in the history
This PR fixes the behavior of IndexDataTypeNormalizer on CastNode.

## Background

Consider the following case,
```python
A = te.placeholder((tir.IntImm("int64", 2), tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
func = te.create_prim_func([A, B], index_dtype_override=None)
```
the generated PrimFunc is
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(4, 2), "float32"]):
    for i0, i1 in T.grid(4, 2):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Here loop variables `ax0` and `ax1` have dtype int32, since the shape of the output buffer is in int32. Other other hand, the input buffer has shape in int64. So as the script above shows, CreatePrimFunc will cast the int32 variables to int64 first, and access the input buffer afterwards.

Now if we use the option `index_dtype_override` to specify an index dtype as below,
```python
func = te.create_prim_func([A, B], index_dtype_override="int64")
```
the generated function will be
```python
@T.prim_func
def func(A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"]):
    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = A[(T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(8) // T.int64(4), (T.Cast("int64", ax0) * T.int64(2) + T.Cast("int64", ax1)) % T.int64(4)]
```
Note that though all variables and the buffer shapes have dtype int64, there are still CastNodes such as `T.Cast("int64", ax0)` when `ax0` is already an int64 variable. We don’t want such redundant casting.

## Fix

To fix the issue above, this PR overrides the `VisitExpr_(const CastNode* cast)` method in IndexDataTypeNormalizer. When the `value` field of a CastNode already has the target dtype, we no longer cast it.
  • Loading branch information
MasterJH5574 authored Nov 21, 2022
1 parent 26d9b5a commit d663207
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
using Parent::VisitStmt_;
PrimExpr VisitExpr_(const IntImmNode* op) final;
PrimExpr VisitExpr_(const VarNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

DataType target_data_type_ = DataType::Int(64);
};
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants,
std::optional<DataType> index_dtype_override) {
// Infomations used in CreatePrimFunc and its sub-functions.
// Information used in CreatePrimFunc and its sub-functions.
CreateFuncInfo info(arg_list);
// Root body stmts.
Array<Stmt> root_stmts;
Expand Down
11 changes: 10 additions & 1 deletion src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {

IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
: target_data_type_(std::move(target_data_type)) {}

PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
Map<Var, Buffer> new_buffer_map = func->buffer_map;
for (const auto& [var, buffer] : func->buffer_map) {
Expand All @@ -534,13 +535,21 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
return (*it).second;
}
if (is_enabled_) {
if (is_enabled_ && op->dtype != target_data_type_) {
Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
var_remap_.Set(GetRef<Var>(op), new_var);
return std::move(new_var);
}
return GetRef<PrimExpr>(op);
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
if (is_enabled_) {
PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
}
return IndexDataTypeRewriter::VisitExpr_(op);
}

} // namespace tir
} // namespace tvm
36 changes: 34 additions & 2 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def test_unique_name_reduction_block():

def _check_workload(te_workload, tir_workload, index_dtype_override=None):
func = te.create_prim_func(te_workload(), index_dtype_override)
print(func.script())
print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload))
tvm.ir.assert_structural_equal(func, tir_workload)
# make sure that we can create schedule from the func
s = tir.Schedule(func, debug_mask="all")
Expand Down Expand Up @@ -575,6 +573,39 @@ def expected(
_check_workload(te_func, expected)


def te_reshape():
# The following is possible to be generated by TOPI. So we test this case.
A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
return [A, B]


@T.prim_func
def tir_reshape(
A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"],
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0, i1 in T.grid(T.int64(4), T.int64(2)):
with T.block("T_reshape"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(
A[
(ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
(ax0 * T.int64(2) + ax1) % T.int64(4),
]
)
T.writes(T_reshape[ax0, ax1])
T_reshape[ax0, ax1] = A[
(ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
(ax0 * T.int64(2) + ax1) % T.int64(4),
]


def test_reshape():
_check_workload(te_reshape, tir_reshape, index_dtype_override="int64")


if __name__ == "__main__":
test_unique_name_complete_block()
test_unique_name_reduction_block()
Expand All @@ -593,3 +624,4 @@ def expected(
test_argmax_val_idx()
test_int64_indices()
test_zero_dim_add()
test_reshape()

0 comments on commit d663207

Please sign in to comment.